-
Notifications
You must be signed in to change notification settings - Fork 34
/
Copy pathtest_functions.py
261 lines (197 loc) · 10.8 KB
/
test_functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
import logging
import os
import subprocess
from datetime import timedelta
from pathlib import Path
import boto3
import pytest
import sagemaker.config
from sagemaker.pytorch import PyTorch
from sagemaker_ssh_helper.log import SSHLog
from sagemaker_ssh_helper.wrapper import SSHEnvironmentWrapper, SSHEstimatorWrapper, SSHModelWrapper
from test_util import _create_bucket_if_doesnt_exist
# Named logger for this test module.
logger = logging.getLogger('sagemaker-ssh-helper:test_functions')
def test_smoke():
    """Trivially passing check: a quick way to exercise the CI/CD pipeline.

    To run only this test, pass
    PYTEST_EXTRA_ARGS=test_functions.py::test_smoke as a GitLab variable.
    """
    assert 42 == 42
def test_ssm_role_from_arn():
    """A role ARN with a ``service-role/`` path maps to the path-qualified role name."""
    role_arn = "arn:aws:iam::012345678901:role/service-role/SageMakerRole"
    ssm_role = SSHEnvironmentWrapper.ssm_role_from_iam_arn(role_arn)
    assert ssm_role == 'service-role/SageMakerRole'
def test_ssm_role_from_arn_cn_us_gov():
    """Role ARNs in the ``aws-cn`` and ``aws-us-gov`` partitions map to the same role name.

    See: https://docs.aws.amazon.com/general/latest/gr/aws-arns-and-namespaces.html
    """
    partition_arns = (
        "arn:aws-cn:iam::012345678901:role/service-role/SageMakerRole",
        "arn:aws-us-gov:iam::012345678901:role/service-role/SageMakerRole",
    )
    for role_arn in partition_arns:
        assert SSHEnvironmentWrapper.ssm_role_from_iam_arn(role_arn) == 'service-role/SageMakerRole'
def test_ssm_role_fail():
    """A path-qualified role name that is not an ARN is expected to raise ValueError."""
    not_an_arn = "service-role/SageMakerRole"
    with pytest.raises(ValueError):
        SSHEnvironmentWrapper.ssm_role_from_iam_arn(not_an_arn)
def test_wrapper_checks_ssm_role_bad_prefix():
    """Expect ValueError when ``ssm_iam_role`` is given as a full IAM ARN."""
    estimator_kwargs = dict(
        entry_point='',
        image_uri='',
        role='arn:aws:iam::012345678901:role/service-role/SageMakerRole',
        instance_count=1,
        instance_type='ml.m5.large',
    )
    with pytest.raises(ValueError):
        # The estimator is built inside the context as well, so a ValueError
        # from either constructor satisfies the expectation.
        estimator = PyTorch(**estimator_kwargs)
        SSHEstimatorWrapper(
            estimator,
            ssm_iam_role='arn:aws:iam::0123456789012:role/service-role/SageMakerRole',
            bootstrap_on_start=True,
            connection_wait_time_seconds=3600,
        )
def test_wrapper_checks_ssm_role_good_prefix():
    """Constructing the wrapper with a path-qualified role name must not raise."""
    estimator = PyTorch(
        entry_point='',
        image_uri='',
        role='arn:aws:iam::012345678901:role/service-role/SageMakerRole',
        instance_count=1,
        instance_type='ml.m5.large',
    )
    SSHEstimatorWrapper(
        estimator,
        ssm_iam_role='service-role/SageMakerRole',
        bootstrap_on_start=True,
        connection_wait_time_seconds=3600,
    )
def test_wrapper_infers_ssm_role():
    """Without an explicit ``ssm_iam_role``, the wrapper exposes the role name from the estimator's role ARN."""
    estimator = PyTorch(
        entry_point='',
        image_uri='',
        role='arn:aws:iam::012345678901:role/service-role/SageMakerRole',
        instance_count=1,
        instance_type='ml.m5.large',
    )
    ssh_wrapper = SSHEstimatorWrapper(
        estimator,
        bootstrap_on_start=True,
        connection_wait_time_seconds=3600,
    )
    assert ssh_wrapper.ssm_iam_role == 'service-role/SageMakerRole'
def test_ssm_role_from_arn_simple():
    """A role ARN without a path maps to the bare role name."""
    role_arn = "arn:aws:iam::012345678901:role/SageMakerRole"
    ssm_role = SSHEnvironmentWrapper.ssm_role_from_iam_arn(role_arn)
    assert ssm_role == 'SageMakerRole'
def test_ssm_role_fail_simple():
    """A bare role name that is not an ARN is expected to raise ValueError."""
    not_an_arn = "SageMakerRole"
    with pytest.raises(ValueError):
        SSHEnvironmentWrapper.ssm_role_from_iam_arn(not_an_arn)
def test_wrapper_checks_ssm_role_bad_prefix_simple():
    """Expect ValueError when ``ssm_iam_role`` is a full IAM ARN (role without a path)."""
    estimator_kwargs = dict(
        entry_point='',
        image_uri='',
        role='arn:aws:iam::012345678901:role/SageMakerRole',
        instance_count=1,
        instance_type='ml.m5.large',
    )
    with pytest.raises(ValueError):
        # The estimator is built inside the context as well, so a ValueError
        # from either constructor satisfies the expectation.
        estimator = PyTorch(**estimator_kwargs)
        SSHEstimatorWrapper(
            estimator,
            ssm_iam_role='arn:aws:iam::0123456789012:role/SageMakerRole',
            bootstrap_on_start=True,
            connection_wait_time_seconds=3600,
        )
def test_wrapper_checks_ssm_role_good_prefix_simple():
    """Constructing the wrapper with a bare role name must not raise."""
    estimator = PyTorch(
        entry_point='',
        image_uri='',
        role='arn:aws:iam::012345678901:role/SageMakerRole',
        instance_count=1,
        instance_type='ml.m5.large',
    )
    SSHEstimatorWrapper(
        estimator,
        ssm_iam_role='SageMakerRole',
        bootstrap_on_start=True,
        connection_wait_time_seconds=3600,
    )
def test_wrapper_infers_ssm_role_simple():
    """Without an explicit ``ssm_iam_role``, the wrapper exposes the bare role name from the estimator's role ARN."""
    estimator = PyTorch(
        entry_point='',
        image_uri='',
        role='arn:aws:iam::012345678901:role/SageMakerRole',
        instance_count=1,
        instance_type='ml.m5.large',
    )
    ssh_wrapper = SSHEstimatorWrapper(
        estimator,
        bootstrap_on_start=True,
        connection_wait_time_seconds=3600,
    )
    assert ssh_wrapper.ssm_iam_role == 'SageMakerRole'
@pytest.mark.skipif(
    os.getenv('PYTEST_IGNORE_SKIPS', "false") == "false",
    reason="Not yet working",
)
def test_model_wrapper_infers_ssm_role_with_defaults():
    """Check that a model wrapper created without an explicit role ends up with a truthy ``ssm_iam_role``.

    Skipped unless the PYTEST_IGNORE_SKIPS environment variable is set to
    something other than "false".
    """
    from sagemaker.huggingface import HuggingFaceModel

    hf_model = HuggingFaceModel(
        model_data='',
        transformers_version='4.17.0',
        pytorch_version='1.10.2',
        py_version='py38',
        dependencies=[SSHModelWrapper.dependency_dir()],
    )

    # TODO: This is not working yet.
    model_wrapper = SSHModelWrapper.create(hf_model, connection_wait_time_seconds=0)
    assert model_wrapper.ssm_iam_role
# noinspection DuplicatedCode
def test_estimator_wrapper_infers_ssm_role_with_defaults():
    """Check that an estimator wrapper created without an explicit role ends up with a truthy ``ssm_iam_role``."""
    fifteen_minutes = int(timedelta(minutes=15).total_seconds())
    pytorch_estimator = PyTorch(
        entry_point='',
        dependencies=[SSHEstimatorWrapper.dependency_dir()],
        framework_version='1.9.1',
        py_version='py38',
        instance_count=1,
        instance_type='ml.m5.xlarge',
        max_run=fifteen_minutes,
        keep_alive_period_in_seconds=1800,
        container_log_level=logging.INFO,
    )

    estimator_wrapper = SSHEstimatorWrapper.create(pytorch_estimator, connection_wait_time_seconds=600)
    assert estimator_wrapper.ssm_iam_role
def test_bucket_exists():
    """Call the bucket-creation helper twice for the same name, then delete the bucket."""
    sts_client = boto3.client('sts')
    account_id = sts_client.get_caller_identity().get('Account')
    custom_bucket_name = f'sagemaker-test-bucket-{account_id}'
    region = 'eu-west-1'

    # First call: result intentionally discarded.
    _create_bucket_if_doesnt_exist(region, custom_bucket_name)
    # Second call with the same arguments; presumably exercises the
    # "bucket already exists" path of the helper — verify against test_util.
    existing_bucket = _create_bucket_if_doesnt_exist(region, custom_bucket_name)

    existing_bucket.delete()
def test_sagemaker_default_config_location():
    """Log the admin and user locations of the SageMaker SDK default config file.

    See: https://sagemaker.readthedocs.io/en/stable/overview.html#default-configuration-file-location
    See: sagemaker.config.config_schema.SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA
    """
    # platformdirs is imported locally because this is the only place in the
    # module that uses it.
    from platformdirs import site_config_dir, user_config_dir

    # The previous f-string "docstring" evaluated this attribute at run time
    # (an f-string is an expression, never a docstring). Keep the lookup
    # explicit so the test still fails if the schema constant disappears.
    _ = sagemaker.config.config_schema.SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA

    # Prints the location of the admin config file
    logging.info(os.path.join(site_config_dir("sagemaker"), "config.yaml"))

    # Prints the location of the user config file
    logging.info(os.path.join(user_config_dir("sagemaker"), "config.yaml"))
def test_dirname():
    """``os.path.dirname`` drops the file name and keeps the nested directory path."""
    script_path = 'source_dir/training/train.py'
    parent_dir = os.path.dirname(script_path)
    assert parent_dir == 'source_dir/training'
def test_cloud_watch_url_training():
    """Training-job CloudWatch URL for an ``SSHLog`` built with no arguments.

    The expected URL assumes the default region resolves to eu-west-1.
    """
    job_name = 'ssh-training-2023-04-20-17-03-10-793'
    expected_url = (
        "https://eu-west-1.console.aws.amazon.com/cloudwatch/home?"
        "region=eu-west-1#logsV2:log-groups/log-group/$252Faws$252Fsagemaker$252FTrainingJobs$3F"
        "logStreamNameFilter$3Dssh-training-2023-04-20-17-03-10-793$252F"
    )

    actual_url = SSHLog().get_training_cloudwatch_url(job_name)
    logging.info(actual_url)

    assert actual_url == expected_url
def test_cloud_watch_url_training_china():
    """Training-job CloudWatch URL for cn-north-1 uses the ``amazonaws.cn`` console host."""
    job_name = 'ssh-training-sklearn-2023-02-20-22-34-59-078'
    expected_url = (
        "https://cn-north-1.console.amazonaws.cn/cloudwatch/home?"
        "region=cn-north-1#logsV2:log-groups/log-group/$252Faws$252Fsagemaker$252FTrainingJobs$3F"
        "logStreamNameFilter$3Dssh-training-sklearn-2023-02-20-22-34-59-078$252F"
    )

    ssh_log = SSHLog(region_name="cn-north-1")
    actual_url = ssh_log.get_training_cloudwatch_url(job_name)
    logging.info(actual_url)

    assert actual_url == expected_url
def test_cloud_watch_url_training_us_gov():
    """Training-job CloudWatch URL for us-gov-west-1 uses the ``amazonaws-us-gov.com`` console host."""
    job_name = 'ssh-training-sklearn-2023-02-20-22-34-59-078'
    expected_url = (
        "https://us-gov-west-1.console.amazonaws-us-gov.com/cloudwatch/home?"
        "region=us-gov-west-1#logsV2:log-groups/log-group/$252Faws$252Fsagemaker$252FTrainingJobs$3F"
        "logStreamNameFilter$3Dssh-training-sklearn-2023-02-20-22-34-59-078$252F"
    )

    ssh_log = SSHLog(region_name="us-gov-west-1")
    actual_url = ssh_log.get_training_cloudwatch_url(job_name)
    logging.info(actual_url)

    assert actual_url == expected_url
def test_cloud_watch_url_endpoint():
    """Endpoint CloudWatch URL for an ``SSHLog`` built with no arguments.

    The expected URL assumes the default region resolves to eu-west-1.
    """
    endpoint_name = 'ssh-inference-tf-2023-04-21-09-07-10-172'
    expected_url = (
        "https://eu-west-1.console.aws.amazon.com/cloudwatch/home?region=eu-west-1#"
        "logsV2:log-groups/log-group/$252Faws$252Fsagemaker$252FEndpoints$252F"
        "ssh-inference-tf-2023-04-21-09-07-10-172"
    )

    actual_url = SSHLog().get_endpoint_cloudwatch_url(endpoint_name)
    logging.info(actual_url)

    assert actual_url == expected_url
def test_cloud_watch_url_transform():
    """Processing-job CloudWatch URL for an ``SSHLog`` built with no arguments.

    NOTE(review): despite the "transform" in the test name, this exercises
    ``get_processing_cloudwatch_url``. The expected URL assumes the default
    region resolves to eu-west-1.
    """
    job_name = 'ssh-pytorch-processing-2023-04-21-08-15-04-579'
    expected_url = (
        "https://eu-west-1.console.aws.amazon.com/cloudwatch/home?region=eu-west-1#"
        "logsV2:log-groups/log-group/$252Faws$252Fsagemaker$252FProcessingJobs$3F"
        "logStreamNameFilter$3Dssh-pytorch-processing-2023-04-21-08-15-04-579$252F"
    )

    actual_url = SSHLog().get_processing_cloudwatch_url(job_name)
    logging.info(actual_url)

    assert actual_url == expected_url
def test_cloud_watch_url_transformer():
    """Transform-job CloudWatch URL for an ``SSHLog`` built with no arguments.

    The expected URL assumes the default region resolves to eu-west-1.
    """
    job_name = 'ssh-batch-transform-2023-04-21-06-45-46-843'
    expected_url = (
        "https://eu-west-1.console.aws.amazon.com/cloudwatch/home?region=eu-west-1#"
        "logsV2:log-groups/log-group/$252Faws$252Fsagemaker$252FTransformJobs$3F"
        "logStreamNameFilter$3Dssh-batch-transform-2023-04-21-06-45-46-843$252F"
    )

    actual_url = SSHLog().get_transform_cloudwatch_url(job_name)

    assert actual_url == expected_url
def test_local_session():
    """The model execution role resolved from config is an IAM ARN for each session type.

    Checked in turn for LocalSession, PipelineSession and LocalPipelineSession.
    """
    from sagemaker.utils import resolve_value_from_config
    from sagemaker import LocalSession
    from sagemaker.config import MODEL_EXECUTION_ROLE_ARN_PATH
    from sagemaker.workflow.pipeline_context import PipelineSession
    from sagemaker.workflow.pipeline_context import LocalPipelineSession

    # Each session is instantiated only when its turn comes, in the same
    # order as before: local, pipeline, local pipeline.
    for session_cls in (LocalSession, PipelineSession, LocalPipelineSession):
        role: str = resolve_value_from_config(
            None,
            MODEL_EXECUTION_ROLE_ARN_PATH,
            sagemaker_session=session_cls(),
        )
        assert role.startswith("arn:aws:iam")
def test_entry_point_source_dir():
    """Show how ``pathlib.Path`` splits a script path into entry point and source dir.

    Also pins two Path behaviours: a Path never compares equal to a plain
    string, and converting it to ``str`` drops the trailing slash.
    """
    script_path = Path('source_dir/inference_hf_accelerate/inference_ssh.py')
    entry_point = script_path.name
    source_dir = str(script_path.parent)

    assert entry_point == 'inference_ssh.py'
    assert source_dir == 'source_dir/inference_hf_accelerate'

    dir_with_trailing_slash = Path('source_dir/inference_hf_accelerate/')
    assert dir_with_trailing_slash != 'source_dir/inference_hf_accelerate'
    assert str(dir_with_trailing_slash) == 'source_dir/inference_hf_accelerate'
def test_called_process_error_with_output():
    """A failing ``sm-local-ssh-ide run-command`` surfaces the SSH error in its captured output.

    The command should fail, because we're not connected to a remote kernel.
    stderr is merged into stdout so the SSH error message ends up in the
    ``output`` attribute of the raised ``CalledProcessError``.
    """
    command = ["sm-local-ssh-ide", "run-command", "python", "--version"]

    # pytest.raises fails the test by itself if no CalledProcessError occurs,
    # so no hand-maintained "got_error" flag is needed.
    with pytest.raises(subprocess.CalledProcessError) as exc_info:
        subprocess.check_output(command, stderr=subprocess.STDOUT)

    output = exc_info.value.output.decode('latin1').strip()
    logger.info("Got error (expected): %s", output)
    assert "ssh: connect to host localhost port 10022: Connection refused" in output