# test_distributed.py
import logging
import os
from datetime import timedelta

import pytest
from sagemaker.pytorch import PyTorch

import sagemaker_ssh_helper.env
from sagemaker_ssh_helper.wrapper import SSHEstimatorWrapper


def test_node_rank_from_env_json():
    """The node rank is read from the resource config under SAGEMAKER_BASE_DIR."""
    os.environ["SAGEMAKER_BASE_DIR"] = os.path.join(os.path.dirname(__file__), "opt_ml")
    node_rank = sagemaker_ssh_helper.env.sm_get_node_rank()
    assert node_rank == 0


def test_node_rank_from_env_json_non_existing_rc():
    """The node rank falls back to 0 when no resource config file exists."""
    os.environ["SAGEMAKER_BASE_DIR"] = os.path.join(os.path.dirname(__file__), "opt_ml_non_existing")
    node_rank = sagemaker_ssh_helper.env.sm_get_node_rank()
    assert node_rank == 0
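

# Illustration only (not part of the test suite): a minimal sketch of what the
# two tests above assume sm_get_node_rank() does. SageMaker writes a
# resourceconfig.json under <base_dir>/input/config/ that lists all hosts and
# the current host; the rank is the current host's index in the sorted host
# list, defaulting to 0 when the file is missing. The path and key names below
# match SageMaker's documented resource config, but treating this as the
# helper's actual implementation is an assumption.
def _sketch_get_node_rank() -> int:
    import json

    base_dir = os.environ.get("SAGEMAKER_BASE_DIR", "/opt/ml")
    rc_path = os.path.join(base_dir, "input", "config", "resourceconfig.json")
    if not os.path.exists(rc_path):
        return 0  # no resource config, e.g., outside a running training job
    with open(rc_path) as f:
        rc = json.load(f)
    return sorted(rc["hosts"]).index(rc["current_host"])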


def test_distributed_training_with_default_instance_count():
    """Without an explicit ssh_instance_count, SSH is enabled on 2 of the 3 nodes."""
    instance_count = 3
    default_ssh_instance_count = 2
    estimator = PyTorch(entry_point='train.py',
                        source_dir='source_dir/training/',
                        dependencies=[SSHEstimatorWrapper.dependency_dir()],
                        base_job_name='ssh-training',
                        framework_version='1.9.1',
                        py_version='py38',
                        instance_count=instance_count,
                        instance_type='ml.m5.xlarge',
                        max_run=int(timedelta(minutes=15).total_seconds()),
                        keep_alive_period_in_seconds=1800,  # keep a warm pool for 30 min
                        container_log_level=logging.INFO)

    ssh_wrapper = SSHEstimatorWrapper.create(estimator, connection_wait_time_seconds=600)
    estimator.fit(wait=False)

    mi_ids = ssh_wrapper.get_instance_ids(timeout_in_sec=600)
    ssh_wrapper.stop_training_job()

    assert len(mi_ids) == default_ssh_instance_count


@pytest.mark.parametrize("ssh_instance_count", [3, 1])
def test_distributed_training_with_changed_instance_count(ssh_instance_count):
    """An explicit ssh_instance_count controls how many nodes get SSH enabled."""
    instance_count = 3
    estimator = PyTorch(entry_point='train.py',
                        source_dir='source_dir/training/',
                        dependencies=[SSHEstimatorWrapper.dependency_dir()],
                        base_job_name='ssh-training',
                        framework_version='1.9.1',
                        py_version='py38',
                        instance_count=instance_count,
                        instance_type='ml.m5.xlarge',
                        max_run=int(timedelta(minutes=15).total_seconds()),
                        keep_alive_period_in_seconds=1800,
                        container_log_level=logging.INFO)

    ssh_wrapper = SSHEstimatorWrapper.create(estimator, connection_wait_time_seconds=600,
                                             ssh_instance_count=ssh_instance_count)
    estimator.fit(wait=False)

    mi_ids = ssh_wrapper.get_instance_ids(timeout_in_sec=600)
    ssh_wrapper.stop_training_job()

    assert len(mi_ids) == ssh_instance_count


def test_distributed_training_mpi_single_node():
    """An MPI-distributed job on a single node exposes exactly one SSH instance."""
    instance_count = 1
    estimator = PyTorch(entry_point='train.py',
                        source_dir='source_dir/training/',
                        dependencies=[SSHEstimatorWrapper.dependency_dir()],
                        base_job_name='ssh-training',
                        framework_version='1.9.1',
                        py_version='py38',
                        instance_count=instance_count,
                        instance_type='ml.g4dn.xlarge',
                        max_run=int(timedelta(minutes=15).total_seconds()),
                        keep_alive_period_in_seconds=1800,
                        container_log_level=logging.INFO,
                        distribution={
                            'mpi': {
                                'enabled': True,
                                'processes_per_host': 4,
                            }
                        })

    ssh_wrapper = SSHEstimatorWrapper.create(estimator, connection_wait_time_seconds=600)
    estimator.fit(wait=False)

    mi_ids = ssh_wrapper.get_instance_ids(timeout_in_sec=600)
    ssh_wrapper.stop_training_job()

    assert len(mi_ids) == 1
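

# Note: get_instance_ids() returns SSM managed-instance ids (strings of the
# form "mi-..."); outside of these tests you would typically open a shell on a
# node with the AWS CLI, e.g.:
#
#   aws ssm start-session --target mi-0123456789abcdef0
#
# The target id above is a made-up placeholder for illustration.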