refactor nproc_per_node for backwards compatibility
brunopistone committed Jan 15, 2025
1 parent 423c585 commit adcc38e
Showing 3 changed files with 35 additions and 7 deletions.
14 changes: 8 additions & 6 deletions src/sagemaker/remote_function/client.py
@@ -91,7 +91,7 @@ def remote(
use_spot_instances=False,
max_wait_time_in_seconds=None,
use_torchrun=False,
nproc_per_node=1,
nproc_per_node: Optional[int] = None,
):
"""Decorator for running the annotated function as a SageMaker training job.
@@ -284,8 +284,8 @@ def remote(
use_torchrun (bool): Specifies whether to use torchrun for distributed training.
Defaults to ``False``.
nproc_per_node (int): Specifies the number of processes per node for
distributed training. Defaults to ``1``.
nproc_per_node (Optional[int]): Specifies the number of processes per node for
distributed training. Defaults to ``None``.
If not specified, it is automatically configured based on the instance type.
"""

@@ -320,6 +320,7 @@ def _remote(func):
use_spot_instances=use_spot_instances,
max_wait_time_in_seconds=max_wait_time_in_seconds,
use_torchrun=use_torchrun,
nproc_per_node=nproc_per_node,
)

@functools.wraps(func)
@@ -536,7 +537,7 @@ def __init__(
use_spot_instances=False,
max_wait_time_in_seconds=None,
use_torchrun=False,
nproc_per_node=1,
nproc_per_node: Optional[int] = None,
):
"""Constructor for RemoteExecutor
@@ -729,8 +730,8 @@ def __init__(
use_torchrun (bool): Specifies whether to use torchrun for distributed training.
Defaults to ``False``.
nproc_per_node (int): Specifies the number of processes per node for
distributed training. Defaults to ``1``.
nproc_per_node (Optional[int]): Specifies the number of processes per node for
distributed training. Defaults to ``None``.
If not specified, it is automatically configured based on the instance type.
"""
self.max_parallel_jobs = max_parallel_jobs
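
A corresponding sketch for RemoteExecutor (assumed values; not part of this commit): the optional nproc_per_node is simply forwarded to the underlying job settings.

from sagemaker.remote_function import RemoteExecutor

def train(epochs):
    ...

with RemoteExecutor(
    instance_type="ml.g5.2xlarge",  # assumed instance type
    use_torchrun=True,
    nproc_per_node=None,  # default: auto-configured from the instance type
    max_parallel_jobs=1,
) as executor:
    future = executor.submit(train, 10)
    result = future.result()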
@@ -777,6 +778,7 @@ def __init__(
use_spot_instances=use_spot_instances,
max_wait_time_in_seconds=max_wait_time_in_seconds,
use_torchrun=use_torchrun,
nproc_per_node=nproc_per_node,
)

self._state_condition = threading.Condition()
20 changes: 19 additions & 1 deletion src/sagemaker/remote_function/job.py
@@ -282,6 +282,7 @@ def __init__(
use_spot_instances=False,
max_wait_time_in_seconds=None,
use_torchrun: bool = False,
nproc_per_node: Optional[int] = None,
):
"""Initialize a _JobSettings instance which configures the remote job.
@@ -463,6 +464,13 @@ def __init__(
max_wait_time_in_seconds (int): Timeout in seconds waiting for spot training job.
After this amount of time Amazon SageMaker will stop waiting for managed spot
training job to complete. Defaults to ``None``.
use_torchrun (bool): Specifies whether to use torchrun for distributed training.
Defaults to ``False``.
nproc_per_node (Optional[int]): Specifies the number of processes per node for
distributed training. Defaults to ``None``.
If not specified, it is automatically configured based on the instance type.
"""
self.sagemaker_session = sagemaker_session or Session()
self.environment_variables = resolve_value_from_config(
@@ -622,6 +630,7 @@ def __init__(
self.tags = self.sagemaker_session._append_sagemaker_config_tags(tags, REMOTE_FUNCTION_TAGS)

self.use_torchrun = use_torchrun
self.nproc_per_node = nproc_per_node

@staticmethod
def _get_default_image(session):
@@ -749,7 +758,6 @@ def start(job_settings: _JobSettings, func, func_args, func_kwargs, run_info=Non
)

logger.info("Creating job: %s", job_name)
logger.info("Environment variables: %s", training_job_request["Environment"])

job_settings.sagemaker_session.sagemaker_client.create_training_job(**training_job_request)

@@ -1027,6 +1035,7 @@ def _prepare_and_upload_runtime_scripts(
s3_kms_key: str,
sagemaker_session: Session,
use_torchrun: bool = False,
nproc_per_node: Optional[int] = None,
):
"""Copy runtime scripts to a folder and upload to S3.
@@ -1044,6 +1053,8 @@ def _prepare_and_upload_runtime_scripts(
sagemaker_session (str): SageMaker boto client session.
use_torchrun (bool): Whether to use torchrun or not.
nproc_per_node (Optional[int]): Number of processes per node.
"""

from sagemaker.workflow.utilities import load_step_compilation_context
@@ -1068,6 +1079,12 @@ def _prepare_and_upload_runtime_scripts(
if use_torchrun:
entry_point_script = ENTRYPOINT_TORCHRUN_SCRIPT

if nproc_per_node is not None and nproc_per_node > 0:
entry_point_script = entry_point_script.replace(
"$SM_NPROC_PER_NODE",
str(nproc_per_node)
)

with open(entrypoint_script_path, "w", newline="\n") as file:
file.writelines(entry_point_script)
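
To illustrate the substitution above, a standalone sketch (the template string here is hypothetical, not the SDK's actual ENTRYPOINT_TORCHRUN_SCRIPT): when nproc_per_node is a positive integer, the $SM_NPROC_PER_NODE placeholder is replaced before the entrypoint is written; otherwise the placeholder is left in place and resolved at training time from the instance type, per the docstring above.

# Hypothetical template; the real ENTRYPOINT_TORCHRUN_SCRIPT differs.
entry_point_script = "torchrun --nproc_per_node $SM_NPROC_PER_NODE train.py"

nproc_per_node = 4  # user-supplied value; None would skip the replacement
if nproc_per_node is not None and nproc_per_node > 0:
    entry_point_script = entry_point_script.replace(
        "$SM_NPROC_PER_NODE", str(nproc_per_node)
    )

print(entry_point_script)  # torchrun --nproc_per_node 4 train.py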

@@ -1106,6 +1123,7 @@ def _generate_input_data_config(job_settings: _JobSettings, s3_base_uri: str):
s3_kms_key=job_settings.s3_kms_key,
sagemaker_session=job_settings.sagemaker_session,
use_torchrun=job_settings.use_torchrun,
nproc_per_node=job_settings.nproc_per_node,
)

input_data_config = [
8 changes: 8 additions & 0 deletions tests/unit/sagemaker/remote_function/test_job.py
@@ -390,6 +390,7 @@ def test_start(
s3_kms_key=None,
sagemaker_session=session(),
use_torchrun=False,
nproc_per_node=None,
)

mock_dependency_upload.assert_called_once_with(
@@ -672,6 +673,7 @@ def test_start_with_complete_job_settings(
s3_kms_key=job_settings.s3_kms_key,
sagemaker_session=session(),
use_torchrun=False,
nproc_per_node=None,
)

mock_user_workspace_upload.assert_called_once_with(
@@ -843,6 +845,7 @@ def test_get_train_args_under_pipeline_context(
s3_kms_key=job_settings.s3_kms_key,
sagemaker_session=session(),
use_torchrun=False,
nproc_per_node=None,
)

mock_user_workspace_upload.assert_called_once_with(
@@ -1018,6 +1021,7 @@ def test_start_with_spark(
s3_kms_key=None,
sagemaker_session=session(),
use_torchrun=False,
nproc_per_node=None,
)

session().sagemaker_client.create_training_job.assert_called_once_with(
@@ -1633,6 +1637,7 @@ def test_start_with_torchrun_single_node(
instance_type="ml.g5.12xlarge",
encrypt_inter_container_traffic=True,
use_torchrun=True,
nproc_per_node=None,
)

job = _Job.start(job_settings, job_function, func_args=(1, 2), func_kwargs={"c": 3, "d": 4})
@@ -1658,6 +1663,7 @@ def test_start_with_torchrun_single_node(
s3_kms_key=None,
sagemaker_session=session(),
use_torchrun=True,
nproc_per_node=None,
)

mock_dependency_upload.assert_called_once_with(
@@ -1759,6 +1765,7 @@ def test_start_with_torchrun_multi_node(
instance_type="ml.g5.2xlarge",
encrypt_inter_container_traffic=True,
use_torchrun=True,
nproc_per_node=None,
)

job = _Job.start(job_settings, job_function, func_args=(1, 2), func_kwargs={"c": 3, "d": 4})
@@ -1784,6 +1791,7 @@ def test_start_with_torchrun_multi_node(
s3_kms_key=None,
sagemaker_session=session(),
use_torchrun=True,
nproc_per_node=None,
)

mock_dependency_upload.assert_called_once_with(
