Skip to content

Commit 9b2301c

Browse files
authored
Merge branch 'master' into canary_action
2 parents 5754c77 + a7baead commit 9b2301c

File tree

33 files changed

+2046
-88
lines changed

33 files changed

+2046
-88
lines changed

CHANGELOG.md

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,40 @@
11
# Changelog
22

3+
## v2.239.0 (2025-02-01)
4+
5+
### Features
6+
7+
* Add support for deepseek recipes
8+
9+
### Bug Fixes and Other Changes
10+
11+
* mpirun protocol - distributed training with @remote decorator
12+
* Allow telemetry only in supported regions
13+
* Fix ssh host policy
14+
15+
## v2.238.0 (2025-01-29)
16+
17+
### Features
18+
19+
* use jumpstart deployment config image as default optimization image
20+
21+
### Bug Fixes and Other Changes
22+
23+
* chore: add new images for HF TGI
24+
* update image_uri_configs 01-29-2025 06:18:08 PST
25+
* skip TF tests for unsupported versions
26+
* Merge branch 'master-rba' into local_merge
27+
* Add missing attributes to local resourceconfig
28+
* update image_uri_configs 01-27-2025 06:18:13 PST
29+
* update image_uri_configs 01-24-2025 06:18:11 PST
30+
* add missing schema definition in docs
31+
* Omegaconf upgrade
32+
* SageMaker @remote function: Added multi-node functionality
33+
* remove option
34+
* fix typo
35+
* fix tests
36+
* Add an option for user to remove inputs and container artifacts when using local model trainer
37+
338
## v2.237.3 (2025-01-09)
439

540
### Bug Fixes and Other Changes

VERSION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
2.237.4.dev0
1+
2.239.1.dev0

src/sagemaker/fw_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@
152152
"2.1.0",
153153
"2.1.2",
154154
"2.2.0",
155+
"2.3.0",
155156
"2.3.1",
156157
"2.4.1",
157158
]

src/sagemaker/image_uri_config/huggingface-llm.json

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
"1.2": "1.2.0",
1313
"1.3": "1.3.3",
1414
"1.4": "1.4.5",
15-
"2.0": "2.3.1"
15+
"2.0": "2.4.0",
16+
"3.0": "3.0.1"
1617
},
1718
"versions": {
1819
"0.6.0": {
@@ -766,6 +767,100 @@
766767
"container_version": {
767768
"gpu": "cu124-ubuntu22.04"
768769
}
770+
},
771+
"2.4.0": {
772+
"py_versions": [
773+
"py311"
774+
],
775+
"registries": {
776+
"af-south-1": "626614931356",
777+
"il-central-1": "780543022126",
778+
"ap-east-1": "871362719292",
779+
"ap-northeast-1": "763104351884",
780+
"ap-northeast-2": "763104351884",
781+
"ap-northeast-3": "364406365360",
782+
"ap-south-1": "763104351884",
783+
"ap-south-2": "772153158452",
784+
"ap-southeast-1": "763104351884",
785+
"ap-southeast-2": "763104351884",
786+
"ap-southeast-3": "907027046896",
787+
"ap-southeast-4": "457447274322",
788+
"ca-central-1": "763104351884",
789+
"cn-north-1": "727897471807",
790+
"cn-northwest-1": "727897471807",
791+
"eu-central-1": "763104351884",
792+
"eu-central-2": "380420809688",
793+
"eu-north-1": "763104351884",
794+
"eu-west-1": "763104351884",
795+
"eu-west-2": "763104351884",
796+
"eu-west-3": "763104351884",
797+
"eu-south-1": "692866216735",
798+
"eu-south-2": "503227376785",
799+
"me-south-1": "217643126080",
800+
"me-central-1": "914824155844",
801+
"sa-east-1": "763104351884",
802+
"us-east-1": "763104351884",
803+
"us-east-2": "763104351884",
804+
"us-gov-east-1": "446045086412",
805+
"us-gov-west-1": "442386744353",
806+
"us-iso-east-1": "886529160074",
807+
"us-isob-east-1": "094389454867",
808+
"us-west-1": "763104351884",
809+
"us-west-2": "763104351884",
810+
"ca-west-1": "204538143572"
811+
},
812+
"tag_prefix": "2.4.0-tgi2.4.0",
813+
"repository": "huggingface-pytorch-tgi-inference",
814+
"container_version": {
815+
"gpu": "cu124-ubuntu22.04-v2.2"
816+
}
817+
},
818+
"3.0.1": {
819+
"py_versions": [
820+
"py311"
821+
],
822+
"registries": {
823+
"af-south-1": "626614931356",
824+
"il-central-1": "780543022126",
825+
"ap-east-1": "871362719292",
826+
"ap-northeast-1": "763104351884",
827+
"ap-northeast-2": "763104351884",
828+
"ap-northeast-3": "364406365360",
829+
"ap-south-1": "763104351884",
830+
"ap-south-2": "772153158452",
831+
"ap-southeast-1": "763104351884",
832+
"ap-southeast-2": "763104351884",
833+
"ap-southeast-3": "907027046896",
834+
"ap-southeast-4": "457447274322",
835+
"ca-central-1": "763104351884",
836+
"cn-north-1": "727897471807",
837+
"cn-northwest-1": "727897471807",
838+
"eu-central-1": "763104351884",
839+
"eu-central-2": "380420809688",
840+
"eu-north-1": "763104351884",
841+
"eu-west-1": "763104351884",
842+
"eu-west-2": "763104351884",
843+
"eu-west-3": "763104351884",
844+
"eu-south-1": "692866216735",
845+
"eu-south-2": "503227376785",
846+
"me-south-1": "217643126080",
847+
"me-central-1": "914824155844",
848+
"sa-east-1": "763104351884",
849+
"us-east-1": "763104351884",
850+
"us-east-2": "763104351884",
851+
"us-gov-east-1": "446045086412",
852+
"us-gov-west-1": "442386744353",
853+
"us-iso-east-1": "886529160074",
854+
"us-isob-east-1": "094389454867",
855+
"us-west-1": "763104351884",
856+
"us-west-2": "763104351884",
857+
"ca-west-1": "204538143572"
858+
},
859+
"tag_prefix": "2.4.0-tgi3.0.1",
860+
"repository": "huggingface-pytorch-tgi-inference",
861+
"container_version": {
862+
"gpu": "cu124-ubuntu22.04-v2.1"
863+
}
769864
}
770865
}
771866
}

src/sagemaker/image_uri_config/huggingface.json

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
"4.17": "4.17.0",
1414
"4.26": "4.26.0",
1515
"4.28": "4.28.1",
16-
"4.36": "4.36.0"
16+
"4.36": "4.36.0",
17+
"4.46": "4.46.1"
1718
},
1819
"versions": {
1920
"4.4.2": {
@@ -1018,6 +1019,53 @@
10181019
"gpu": "cu121-ubuntu20.04"
10191020
}
10201021
}
1022+
},
1023+
"4.46.1": {
1024+
"version_aliases": {
1025+
"pytorch2.3": "pytorch2.3.0"
1026+
},
1027+
"pytorch2.3.0": {
1028+
"py_versions": [
1029+
"py311"
1030+
],
1031+
"registries": {
1032+
"af-south-1": "626614931356",
1033+
"il-central-1": "780543022126",
1034+
"ap-east-1": "871362719292",
1035+
"ap-northeast-1": "763104351884",
1036+
"ap-northeast-2": "763104351884",
1037+
"ap-northeast-3": "364406365360",
1038+
"ap-south-1": "763104351884",
1039+
"ap-southeast-1": "763104351884",
1040+
"ap-southeast-2": "763104351884",
1041+
"ap-southeast-3": "907027046896",
1042+
"ca-central-1": "763104351884",
1043+
"cn-north-1": "727897471807",
1044+
"cn-northwest-1": "727897471807",
1045+
"eu-central-1": "763104351884",
1046+
"eu-north-1": "763104351884",
1047+
"eu-west-1": "763104351884",
1048+
"eu-west-2": "763104351884",
1049+
"eu-west-3": "763104351884",
1050+
"eu-south-1": "692866216735",
1051+
"me-south-1": "217643126080",
1052+
"me-central-1": "914824155844",
1053+
"sa-east-1": "763104351884",
1054+
"us-east-1": "763104351884",
1055+
"us-east-2": "763104351884",
1056+
"us-gov-east-1": "446045086412",
1057+
"us-gov-west-1": "442386744353",
1058+
"us-iso-east-1": "886529160074",
1059+
"us-isob-east-1": "094389454867",
1060+
"us-west-1": "763104351884",
1061+
"us-west-2": "763104351884",
1062+
"ca-west-1": "204538143572"
1063+
},
1064+
"repository": "huggingface-pytorch-training",
1065+
"container_version": {
1066+
"gpu": "cu121-ubuntu20.04"
1067+
}
1068+
}
10211069
}
10221070
}
10231071
},

src/sagemaker/image_uri_config/sagemaker-base-python.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
"ap-southeast-1": "492261229750",
1212
"ap-southeast-2": "452832661640",
1313
"ap-southeast-3": "276181064229",
14+
"ap-southeast-5": "148761635175",
1415
"ca-central-1": "310906938811",
1516
"cn-north-1": "390048526115",
1617
"cn-northwest-1": "390780980154",

src/sagemaker/modules/train/container_drivers/mpi_utils.py

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@
1414
from __future__ import absolute_import
1515

1616
import os
17-
import time
1817
import subprocess
19-
18+
import time
2019
from typing import List
2120

22-
from utils import logger, SM_EFA_NCCL_INSTANCES, SM_EFA_RDMA_INSTANCES, get_python_executable
21+
import paramiko
22+
from utils import SM_EFA_NCCL_INSTANCES, SM_EFA_RDMA_INSTANCES, get_python_executable, logger
2323

2424
FINISHED_STATUS_FILE = "/tmp/done.algo-1"
2525
READY_FILE = "/tmp/ready.%s"
@@ -75,19 +75,45 @@ def start_sshd_daemon():
7575
logger.info("Started SSH daemon.")
7676

7777

78+
class CustomHostKeyPolicy(paramiko.client.MissingHostKeyPolicy):
79+
"""Class to handle host key policy for SageMaker distributed training SSH connections.
80+
81+
Example:
82+
>>> client = paramiko.SSHClient()
83+
>>> client.set_missing_host_key_policy(CustomHostKeyPolicy())
84+
>>> # Will succeed for SageMaker algorithm containers
85+
>>> client.connect('algo-1234.internal')
86+
>>> # Will raise SSHException for other unknown hosts
87+
>>> client.connect('unknown-host') # raises SSHException
88+
"""
89+
90+
def missing_host_key(self, client, hostname, key):
91+
"""Accept host keys for algo-* hostnames, reject others.
92+
93+
Args:
94+
client: The SSHClient instance
95+
hostname: The hostname attempting to connect
96+
key: The host key
97+
98+
Raises:
99+
paramiko.SSHException: If hostname doesn't match algo-* pattern
100+
"""
101+
if hostname.startswith("algo-"):
102+
client.get_host_keys().add(hostname, key.get_name(), key)
103+
return
104+
raise paramiko.SSHException(f"Unknown host key for {hostname}")
105+
106+
78107
def _can_connect(host: str, port: int = DEFAULT_SSH_PORT) -> bool:
79108
"""Check if the connection to the provided host and port is possible."""
80109
try:
81-
import paramiko
82-
83110
logger.debug("Testing connection to host %s", host)
84-
client = paramiko.SSHClient()
85-
client.load_system_host_keys()
86-
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
87-
client.connect(host, port=port)
88-
client.close()
89-
logger.info("Can connect to host %s", host)
90-
return True
111+
with paramiko.SSHClient() as client:
112+
client.load_system_host_keys()
113+
client.set_missing_host_key_policy(CustomHostKeyPolicy())
114+
client.connect(host, port=port)
115+
logger.info("Can connect to host %s", host)
116+
return True
91117
except Exception as e: # pylint: disable=W0703
92118
logger.info("Cannot connect to host %s", host)
93119
logger.debug(f"Connection failed with exception: {e}")
@@ -183,9 +209,9 @@ def validate_smddpmprun() -> bool:
183209

184210
def write_env_vars_to_file():
185211
"""Write environment variables to /etc/environment file."""
186-
with open("/etc/environment", "a") as f:
212+
with open("/etc/environment", "a", encoding="utf-8") as f:
187213
for name in os.environ:
188-
f.write("{}={}\n".format(name, os.environ.get(name)))
214+
f.write(f"{name}={os.environ.get(name)}\n")
189215

190216

191217
def get_mpirun_command(

src/sagemaker/modules/train/sm_recipes/utils.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,27 @@ def _register_custom_resolvers():
125125
OmegaConf.register_new_resolver("add", lambda *numbers: sum(numbers))
126126

127127

128+
def _get_trainining_recipe_gpu_model_name_and_script(model_type: str):
129+
"""Get the model base name and script for the training recipe."""
130+
131+
model_type_to_script = {
132+
"llama_v3": ("llama", "llama_pretrain.py"),
133+
"mistral": ("mistral", "mistral_pretrain.py"),
134+
"mixtral": ("mixtral", "mixtral_pretrain.py"),
135+
"deepseek": ("deepseek", "deepseek_pretrain.py"),
136+
}
137+
138+
for key in model_type_to_script:
139+
if model_type.startswith(key):
140+
model_type = key
141+
break
142+
143+
if model_type not in model_type_to_script:
144+
raise ValueError(f"Model type {model_type} not supported")
145+
146+
return model_type_to_script[model_type][0], model_type_to_script[model_type][1]
147+
148+
128149
def _configure_gpu_args(
129150
training_recipes_cfg: Dict[str, Any],
130151
region_name: str,
@@ -140,24 +161,16 @@ def _configure_gpu_args(
140161
)
141162
_run_clone_command_silent(adapter_repo, recipe_train_dir.name)
142163

143-
model_type_to_entry = {
144-
"llama_v3": ("llama", "llama_pretrain.py"),
145-
"mistral": ("mistral", "mistral_pretrain.py"),
146-
"mixtral": ("mixtral", "mixtral_pretrain.py"),
147-
}
148-
149164
if "model" not in recipe:
150165
raise ValueError("Supplied recipe does not contain required field model.")
151166
if "model_type" not in recipe["model"]:
152167
raise ValueError("Supplied recipe does not contain required field model_type.")
153168
model_type = recipe["model"]["model_type"]
154-
if model_type not in model_type_to_entry:
155-
raise ValueError(f"Model type {model_type} not supported")
156169

157-
source_code.source_dir = os.path.join(
158-
recipe_train_dir.name, "examples", model_type_to_entry[model_type][0]
159-
)
160-
source_code.entry_script = model_type_to_entry[model_type][1]
170+
model_base_name, script = _get_trainining_recipe_gpu_model_name_and_script(model_type)
171+
172+
source_code.source_dir = os.path.join(recipe_train_dir.name, "examples", model_base_name)
173+
source_code.entry_script = script
161174

162175
gpu_image_cfg = training_recipes_cfg.get("gpu_image")
163176
if isinstance(gpu_image_cfg, str):

src/sagemaker/pytorch/estimator.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,13 +95,20 @@ def _get_training_recipe_gpu_script(code_dir, recipe, source_dir):
9595
"llama_v3": ("llama", "llama_pretrain.py"),
9696
"mistral": ("mistral", "mistral_pretrain.py"),
9797
"mixtral": ("mixtral", "mixtral_pretrain.py"),
98+
"deepseek": ("deepseek", "deepseek_pretrain.py"),
9899
}
99100

100101
if "model" not in recipe:
101102
raise ValueError("Supplied recipe does not contain required field model.")
102103
if "model_type" not in recipe["model"]:
103104
raise ValueError("Supplied recipe does not contain required field model_type.")
104105
model_type = recipe["model"]["model_type"]
106+
107+
for key in model_type_to_script:
108+
if model_type.startswith(key):
109+
model_type = key
110+
break
111+
105112
if model_type not in model_type_to_script:
106113
raise ValueError(f"Model type {model_type} not supported")
107114

0 commit comments

Comments
 (0)