From adcc38e5d8ae761357e73240554c9949a7f0c765 Mon Sep 17 00:00:00 2001
From: brunopistone
Date: Wed, 15 Jan 2025 18:55:50 +0000
Subject: [PATCH] refactor nproc_per_node for backwards compatibility

---
 src/sagemaker/remote_function/client.py    | 14 +++++++------
 src/sagemaker/remote_function/job.py       | 20 ++++++++++++++++++-
 .../sagemaker/remote_function/test_job.py  |  8 ++++++++
 3 files changed, 35 insertions(+), 7 deletions(-)

diff --git a/src/sagemaker/remote_function/client.py b/src/sagemaker/remote_function/client.py
index 2e8b2b7d09..15051dc04a 100644
--- a/src/sagemaker/remote_function/client.py
+++ b/src/sagemaker/remote_function/client.py
@@ -91,7 +91,7 @@ def remote(
     use_spot_instances=False,
     max_wait_time_in_seconds=None,
     use_torchrun=False,
-    nproc_per_node=1,
+    nproc_per_node: Optional[int] = None,
 ):
     """Decorator for running the annotated function as a SageMaker training job.
 
@@ -284,8 +284,8 @@ def remote(
 
         use_torchrun (bool): Specifies whether to use torchrun for distributed training.
             Defaults to ``False``.
 
-        nproc_per_node (int): Specifies the number of processes per node for
-            distributed training. Defaults to ``1``.
+        nproc_per_node (Optional[int]): Specifies the number of processes per node for
+            distributed training. Defaults to ``None`` (configured automatically based on the instance type).
     """
 
@@ -320,6 +320,7 @@ def _remote(func):
             use_spot_instances=use_spot_instances,
             max_wait_time_in_seconds=max_wait_time_in_seconds,
             use_torchrun=use_torchrun,
+            nproc_per_node=nproc_per_node,
         )
 
         @functools.wraps(func)
@@ -536,7 +537,7 @@ def __init__(
         use_spot_instances=False,
         max_wait_time_in_seconds=None,
         use_torchrun=False,
-        nproc_per_node=1,
+        nproc_per_node: Optional[int] = None,
     ):
         """Constructor for RemoteExecutor
 
@@ -729,8 +730,8 @@ def __init__(
 
             use_torchrun (bool): Specifies whether to use torchrun for distributed training.
                 Defaults to ``False``.
 
-            nproc_per_node (int): Specifies the number of processes per node for
-                distributed training. Defaults to ``1``.
+            nproc_per_node (Optional[int]): Specifies the number of processes per node for
+                distributed training. Defaults to ``None`` (configured automatically based on the instance type).
         """
         self.max_parallel_jobs = max_parallel_jobs
@@ -777,6 +778,7 @@ def __init__(
             use_spot_instances=use_spot_instances,
             max_wait_time_in_seconds=max_wait_time_in_seconds,
             use_torchrun=use_torchrun,
+            nproc_per_node=nproc_per_node,
         )
 
         self._state_condition = threading.Condition()
diff --git a/src/sagemaker/remote_function/job.py b/src/sagemaker/remote_function/job.py
index 51f07046a4..86574c0bdf 100644
--- a/src/sagemaker/remote_function/job.py
+++ b/src/sagemaker/remote_function/job.py
@@ -282,6 +282,7 @@ def __init__(
         use_spot_instances=False,
         max_wait_time_in_seconds=None,
         use_torchrun: bool = False,
+        nproc_per_node: Optional[int] = None,
     ):
         """Initialize a _JobSettings instance which configures the remote job.
 
@@ -463,6 +464,13 @@ def __init__(
             max_wait_time_in_seconds (int): Timeout in seconds waiting for spot training job.
                 After this amount of time Amazon SageMaker will stop waiting for managed spot
                 training job to complete. Defaults to ``None``.
+
+            use_torchrun (bool): Specifies whether to use torchrun for distributed training.
+                Defaults to ``False``.
+
+            nproc_per_node (Optional[int]): Specifies the number of processes per node for
+                distributed training. Defaults to ``None``.
+                If not set, it is configured automatically based on the instance type.
""" self.sagemaker_session = sagemaker_session or Session() self.environment_variables = resolve_value_from_config( @@ -622,6 +630,7 @@ def __init__( self.tags = self.sagemaker_session._append_sagemaker_config_tags(tags, REMOTE_FUNCTION_TAGS) self.use_torchrun = use_torchrun + self.nproc_per_node = nproc_per_node @staticmethod def _get_default_image(session): @@ -749,7 +758,6 @@ def start(job_settings: _JobSettings, func, func_args, func_kwargs, run_info=Non ) logger.info("Creating job: %s", job_name) - logger.info("Environment variables: %s", training_job_request["Environment"]) job_settings.sagemaker_session.sagemaker_client.create_training_job(**training_job_request) @@ -1027,6 +1035,7 @@ def _prepare_and_upload_runtime_scripts( s3_kms_key: str, sagemaker_session: Session, use_torchrun: bool = False, + nproc_per_node: Optional[int] = None, ): """Copy runtime scripts to a folder and upload to S3. @@ -1044,6 +1053,8 @@ def _prepare_and_upload_runtime_scripts( sagemaker_session (str): SageMaker boto client session. use_torchrun (bool): Whether to use torchrun or not. + + nproc_per_node (Optional[int]): Number of processes per node """ from sagemaker.workflow.utilities import load_step_compilation_context @@ -1068,6 +1079,12 @@ def _prepare_and_upload_runtime_scripts( if use_torchrun: entry_point_script = ENTRYPOINT_TORCHRUN_SCRIPT + if nproc_per_node is not None and nproc_per_node > 0: + entry_point_script = entry_point_script.replace( + "$SM_NPROC_PER_NODE", + str(nproc_per_node) + ) + with open(entrypoint_script_path, "w", newline="\n") as file: file.writelines(entry_point_script) @@ -1106,6 +1123,7 @@ def _generate_input_data_config(job_settings: _JobSettings, s3_base_uri: str): s3_kms_key=job_settings.s3_kms_key, sagemaker_session=job_settings.sagemaker_session, use_torchrun=job_settings.use_torchrun, + nproc_per_node=job_settings.nproc_per_node, ) input_data_config = [ diff --git a/tests/unit/sagemaker/remote_function/test_job.py b/tests/unit/sagemaker/remote_function/test_job.py index 56ac67d284..48637cd19d 100644 --- a/tests/unit/sagemaker/remote_function/test_job.py +++ b/tests/unit/sagemaker/remote_function/test_job.py @@ -390,6 +390,7 @@ def test_start( s3_kms_key=None, sagemaker_session=session(), use_torchrun=False, + nproc_per_node=None, ) mock_dependency_upload.assert_called_once_with( @@ -672,6 +673,7 @@ def test_start_with_complete_job_settings( s3_kms_key=job_settings.s3_kms_key, sagemaker_session=session(), use_torchrun=False, + nproc_per_node=None, ) mock_user_workspace_upload.assert_called_once_with( @@ -843,6 +845,7 @@ def test_get_train_args_under_pipeline_context( s3_kms_key=job_settings.s3_kms_key, sagemaker_session=session(), use_torchrun=False, + nproc_per_node=None, ) mock_user_workspace_upload.assert_called_once_with( @@ -1018,6 +1021,7 @@ def test_start_with_spark( s3_kms_key=None, sagemaker_session=session(), use_torchrun=False, + nproc_per_node=None, ) session().sagemaker_client.create_training_job.assert_called_once_with( @@ -1633,6 +1637,7 @@ def test_start_with_torchrun_single_node( instance_type="ml.g5.12xlarge", encrypt_inter_container_traffic=True, use_torchrun=True, + nproc_per_node=None, ) job = _Job.start(job_settings, job_function, func_args=(1, 2), func_kwargs={"c": 3, "d": 4}) @@ -1658,6 +1663,7 @@ def test_start_with_torchrun_single_node( s3_kms_key=None, sagemaker_session=session(), use_torchrun=True, + nproc_per_node=None, ) mock_dependency_upload.assert_called_once_with( @@ -1759,6 +1765,7 @@ def test_start_with_torchrun_multi_node( 
instance_type="ml.g5.2xlarge", encrypt_inter_container_traffic=True, use_torchrun=True, + nproc_per_node=None, ) job = _Job.start(job_settings, job_function, func_args=(1, 2), func_kwargs={"c": 3, "d": 4}) @@ -1784,6 +1791,7 @@ def test_start_with_torchrun_multi_node( s3_kms_key=None, sagemaker_session=session(), use_torchrun=True, + nproc_per_node=None, ) mock_dependency_upload.assert_called_once_with(