test_hpo.py
import logging
import os
import time
from datetime import timedelta

import pytest
from sagemaker.mxnet import MXNet
from sagemaker.parameter import CategoricalParameter, ContinuousParameter, IntegerParameter
from sagemaker.tuner import HyperparameterTuner

from sagemaker_ssh_helper.log import SSHLog
from sagemaker_ssh_helper.manager import SSMManager
from sagemaker_ssh_helper.wrapper import SSHEstimatorWrapper
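
# NOTE: both tests below are integration tests that launch real SageMaker
# training and tuning jobs, so they assume AWS credentials and an execution
# role are configured for the SageMaker SDK and will incur AWS charges.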


def test_clean_hpo():
    # entry_point is just the script file name, resolved relative to source_dir
    estimator = MXNet(entry_point=os.path.basename('source_dir/training_clean/train_clean.py'),
                      source_dir='source_dir/training_clean/',
                      dependencies=[SSHEstimatorWrapper.dependency_dir()],
                      py_version='py38',
                      framework_version='1.9',
                      instance_count=1,
                      instance_type='ml.m5.xlarge',
                      max_run=int(timedelta(minutes=15).total_seconds()),
                      container_log_level=logging.INFO)

    # Adapted from https://github.com/aws/amazon-sagemaker-examples/blob/main/hyperparameter_tuning/mxnet_mnist/hpo_mxnet_mnist.ipynb
    hyperparameter_ranges = {
        "optimizer": CategoricalParameter(["sgd", "Adam"]),
        "learning_rate": ContinuousParameter(0.01, 0.2),
        "num_epoch": IntegerParameter(10, 50),
    }

    objective_metric_name = "model-accuracy"
    metric_definitions = [{"Name": "model-accuracy", "Regex": "model-accuracy=([0-9\\.]+)"}]

    tuner = HyperparameterTuner(
        estimator,
        objective_metric_name,
        hyperparameter_ranges,
        metric_definitions,
        max_jobs=3,
        max_parallel_jobs=2,
    )

    tuner.fit()  # blocks until the tuning job completes

    best_training_job = tuner.best_training_job()
    assert best_training_job is not None
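
# test_hpo_ssh mirrors test_clean_hpo, but instruments the estimator with
# SSHEstimatorWrapper so that each tuning job starts an SSM agent, then
# attaches to the first spawned training job and checks that its instance ID
# can be retrieved.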


def test_hpo_ssh():
    estimator = MXNet(entry_point=os.path.basename('source_dir/training/train.py'),
                      source_dir='source_dir/training/',
                      dependencies=[SSHEstimatorWrapper.dependency_dir()],
                      py_version='py38',
                      framework_version='1.9',
                      instance_count=1,
                      instance_type='ml.m5.xlarge',
                      max_run=int(timedelta(minutes=15).total_seconds()),
                      container_log_level=logging.INFO)

    # Instrument the estimator so every tuning job starts the SSH (SSM) agent
    ssh_wrapper = SSHEstimatorWrapper.create(estimator, connection_wait_time_seconds=60)

    hyperparameter_ranges = {
        "optimizer": CategoricalParameter(["sgd", "Adam"]),
        "learning_rate": ContinuousParameter(0.01, 0.2),
        "num_epoch": IntegerParameter(10, 50),
    }

    objective_metric_name = "model-accuracy"
    metric_definitions = [{"Name": "model-accuracy", "Regex": "model-accuracy=([0-9\\.]+)"}]

    tuner = HyperparameterTuner(
        estimator,
        objective_metric_name,
        hyperparameter_ranges,
        metric_definitions,
        base_tuning_job_name='ssh-hpo-mxnet',
        max_jobs=3,
        max_parallel_jobs=2,
    )

    tuner.fit(wait=False)

    with pytest.raises(AssertionError):
        # Shouldn't be able to get instance IDs without calling estimator.fit() first
        _ = ssh_wrapper.get_instance_ids(timeout_in_sec=0)

    time.sleep(15)  # allow training jobs to start

    analytics = tuner.analytics()
    training_jobs = analytics.training_job_summaries()
    training_job_name = training_jobs[0]['TrainingJobName']

    # Re-attach the wrapper to the first training job spawned by the tuner
    ssh_wrapper = SSHEstimatorWrapper.attach(training_job_name)
    instance_ids = ssh_wrapper.get_instance_ids()
    assert len(instance_ids) == 1

    tuner.wait()

    best_training_job = tuner.best_training_job()
    assert best_training_job is not None
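
# With the instance ID in hand, one could open an interactive session to the
# training container. Assuming the standard sagemaker-ssh-helper local setup,
# something like the following would work from a terminal (<training_job_name>
# is a placeholder for the job name printed by the tuner):
#   sm-local-ssh-training connect <training_job_name>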