Skip to content

Commit b695743

Browse files
committed
feat: Add support for deepseek recipes
1 parent c753da0 commit b695743

File tree

3 files changed

+67
-12
lines changed

3 files changed

+67
-12
lines changed

src/sagemaker/modules/train/sm_recipes/utils.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,27 @@ def _register_custom_resolvers():
125125
OmegaConf.register_new_resolver("add", lambda *numbers: sum(numbers))
126126

127127

128+
def _get_trainining_recipe_gpu_model_name_and_script(model_type: str):
129+
"""Get the model base name and script for the training recipe."""
130+
131+
model_type_to_script = {
132+
"llama_v3": ("llama", "llama_pretrain.py"),
133+
"mistral": ("mistral", "mistral_pretrain.py"),
134+
"mixtral": ("mixtral", "mixtral_pretrain.py"),
135+
"deepseek": ("deepseek", "deepseek_pretrain.py"),
136+
}
137+
138+
for key in model_type_to_script.keys():
139+
if model_type.startswith(key):
140+
model_type = key
141+
break
142+
143+
if model_type not in model_type_to_script:
144+
raise ValueError(f"Model type {model_type} not supported")
145+
146+
return model_type_to_script[model_type][0], model_type_to_script[model_type][1]
147+
148+
128149
def _configure_gpu_args(
129150
training_recipes_cfg: Dict[str, Any],
130151
region_name: str,
@@ -140,24 +161,16 @@ def _configure_gpu_args(
140161
)
141162
_run_clone_command_silent(adapter_repo, recipe_train_dir.name)
142163

143-
model_type_to_entry = {
144-
"llama_v3": ("llama", "llama_pretrain.py"),
145-
"mistral": ("mistral", "mistral_pretrain.py"),
146-
"mixtral": ("mixtral", "mixtral_pretrain.py"),
147-
}
148-
149164
if "model" not in recipe:
150165
raise ValueError("Supplied recipe does not contain required field model.")
151166
if "model_type" not in recipe["model"]:
152167
raise ValueError("Supplied recipe does not contain required field model_type.")
153168
model_type = recipe["model"]["model_type"]
154-
if model_type not in model_type_to_entry:
155-
raise ValueError(f"Model type {model_type} not supported")
156169

157-
source_code.source_dir = os.path.join(
158-
recipe_train_dir.name, "examples", model_type_to_entry[model_type][0]
159-
)
160-
source_code.entry_script = model_type_to_entry[model_type][1]
170+
model_base_name, script = _get_trainining_recipe_gpu_model_name_and_script(model_type)
171+
172+
source_code.source_dir = os.path.join(recipe_train_dir.name, "examples", model_base_name)
173+
source_code.entry_script = script
161174

162175
gpu_image_cfg = training_recipes_cfg.get("gpu_image")
163176
if isinstance(gpu_image_cfg, str):

src/sagemaker/pytorch/estimator.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,13 +95,20 @@ def _get_training_recipe_gpu_script(code_dir, recipe, source_dir):
9595
"llama_v3": ("llama", "llama_pretrain.py"),
9696
"mistral": ("mistral", "mistral_pretrain.py"),
9797
"mixtral": ("mixtral", "mixtral_pretrain.py"),
98+
"deepseek": ("deepseek", "deepseek_pretrain.py"),
9899
}
99100

100101
if "model" not in recipe:
101102
raise ValueError("Supplied recipe does not contain required field model.")
102103
if "model_type" not in recipe["model"]:
103104
raise ValueError("Supplied recipe does not contain required field model_type.")
104105
model_type = recipe["model"]["model_type"]
106+
107+
for key in model_type_to_script.keys():
108+
if model_type.startswith(key):
109+
model_type = key
110+
break
111+
105112
if model_type not in model_type_to_script:
106113
raise ValueError(f"Model type {model_type} not supported")
107114

tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
_load_recipes_cfg,
2727
_configure_gpu_args,
2828
_configure_trainium_args,
29+
_get_trainining_recipe_gpu_model_name_and_script,
2930
)
3031
from sagemaker.modules.utils import _run_clone_command_silent
3132
from sagemaker.modules.configs import Compute
@@ -178,3 +179,37 @@ def test_get_args_from_recipe_compute(
178179
assert mock_gpu_args.call_count == 0
179180
assert mock_trainium_args.call_count == 0
180181
assert args is None
182+
183+
@pytest.mark.parametrize(
    "test_case",
    [
        {
            "model_type": "llama_v3",
            "script": "llama_pretrain.py",
            # The helper maps the "llama_v3" family to base name "llama"
            # (see model_type_to_script in utils.py), not "llama_v3".
            "model_base_name": "llama",
        },
        {
            "model_type": "mistral",
            "script": "mistral_pretrain.py",
            "model_base_name": "mistral",
        },
        {
            "model_type": "deepseek_llamav3",
            "script": "deepseek_pretrain.py",
            "model_base_name": "deepseek",
        },
        {
            "model_type": "deepseek_qwenv2",
            "script": "deepseek_pretrain.py",
            "model_base_name": "deepseek",
        },
    ],
)
def test_get_trainining_recipe_gpu_model_name_and_script(test_case):
    """Prefix-based model types resolve to the expected base name and script.

    Fixes over the previous version: the helper takes a single ``model_type``
    argument (it was being called with a spurious second argument, which
    raised TypeError), and the ``llama_v3`` case now expects the base name
    the helper actually returns ("llama").
    """
    model_base_name, script = _get_trainining_recipe_gpu_model_name_and_script(
        test_case["model_type"]
    )
    assert model_base_name == test_case["model_base_name"]
    assert script == test_case["script"]

0 commit comments

Comments
 (0)