
feature: Add heterogeneous cluster changes (#421)
* feature: Add heterogeneous cluster changes
satishpasumarthi authored Jul 8, 2022
1 parent 30fed8f commit 67171fa
Showing 3 changed files with 23 additions and 17 deletions.
setup.py (2 changes: 1 addition & 1 deletion)

@@ -45,7 +45,7 @@ def read_version():
     "botocore==1.19.34",
     "requests-mock",
     "awscli==1.18.194",
-    "protobuf>=3.20,<3.21"
+    "protobuf>=3.9.2,<3.20"
 ]
 
 if sys.version_info.major > 2:
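The dependency change moves protobuf from the 3.20 series back to the pre-3.20 range, presumably because protobuf 3.20 introduced breaking changes for code generated by older protoc and is incompatible with the TensorFlow versions this container targets. A quick way to sanity-check an environment against the new range (illustrative only, not part of this commit; the `packaging` library is an assumed extra dependency):

# Verify the installed protobuf falls inside the range pinned in setup.py.
from importlib.metadata import version          # Python 3.8+
from packaging.specifiers import SpecifierSet   # assumed available in the env

PIN = SpecifierSet(">=3.9.2,<3.20")             # same range as the commit

installed = version("protobuf")
assert installed in PIN, f"protobuf {installed} is outside the pinned range {PIN}"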
src/sagemaker_tensorflow_container/training.py (33 changes: 17 additions & 16 deletions)

@@ -176,17 +176,20 @@ def train(env, cmd_args):
     env_vars = env.to_env_vars()
 
     # Setup
-    if parameter_server_enabled:
+    if env.current_instance_group in env.distribution_instance_groups:
+        if parameter_server_enabled:
 
-        tf_config = _build_tf_config_for_ps(hosts=env.hosts, current_host=env.current_host)
-        logger.info("Running distributed training job with parameter servers")
+            tf_config = _build_tf_config_for_ps(hosts=env.distribution_hosts, current_host=env.current_host)
+            logger.info("Running distributed training job with parameter servers")
 
-    elif multi_worker_mirrored_strategy_enabled:
+        elif multi_worker_mirrored_strategy_enabled:
 
-        env_vars["TF_CONFIG"] = json.dumps(
-            _build_tf_config_for_mwms(hosts=env.hosts, current_host=env.current_host)
-        )
-        logger.info("Running distributed training job with multi_worker_mirrored_strategy setup")
+            env_vars["TF_CONFIG"] = json.dumps(
+                _build_tf_config_for_mwms(hosts=env.distribution_hosts, current_host=env.current_host)
+            )
+            logger.info("Running distributed training job with multi_worker_mirrored_strategy setup")
+
+    runner_type = runner.ProcessRunnerType
 
     # Run
     if parameter_server_enabled:
@@ -200,15 +203,13 @@ def train(env, cmd_args):
         _wait_until_master_is_down(env.hosts[0])
 
     else:
-
-        mpi_enabled = env.additional_framework_parameters.get("sagemaker_mpi_enabled")
+        if env.current_instance_group in env.distribution_instance_groups:
+            mpi_enabled = env.additional_framework_parameters.get("sagemaker_mpi_enabled")
 
-        if mpi_enabled:
-            runner_type = runner.MPIRunnerType
-        elif sagemaker_distributed_dataparallel_enabled:
-            runner_type = runner.SMDataParallelRunnerType
-        else:
-            runner_type = runner.ProcessRunnerType
+            if mpi_enabled:
+                runner_type = runner.MPIRunnerType
+            elif sagemaker_distributed_dataparallel_enabled:
+                runner_type = runner.SMDataParallelRunnerType
 
     entry_point.run(
         uri=env.module_dir,
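This is the core of the change. With SageMaker heterogeneous clusters, one training job can mix instance groups (for example, a GPU group that trains and a CPU group that feeds data), and only the groups named in distribution_instance_groups should take part in the distributed setup. TF_CONFIG is now built from env.distribution_hosts rather than all of env.hosts, and runner_type defaults to ProcessRunnerType so hosts outside the distribution groups simply run the entry point as a plain local process. A minimal sketch of the gating (simplified, not the toolkit's actual code; the Env stand-in and the worker port are assumptions, names mirror the diff):

import json
from dataclasses import dataclass
from typing import Dict, List

@dataclass
class Env:
    """Stand-in for the SageMaker training environment object used by train()."""
    current_host: str
    current_instance_group: str
    hosts: List[str]                         # every host in the cluster
    distribution_hosts: List[str]            # hosts that participate in distributed training
    distribution_instance_groups: List[str]  # instance groups that participate

def build_tf_config_for_mwms(hosts: List[str], current_host: str) -> Dict:
    # General shape of a MultiWorkerMirroredStrategy TF_CONFIG; the port is illustrative.
    workers = [f"{host}:8890" for host in sorted(hosts)]
    return {
        "cluster": {"worker": workers},
        "task": {"type": "worker", "index": sorted(hosts).index(current_host)},
        "environment": "cloud",
    }

def setup_env_vars(env: Env, mwms_enabled: bool) -> Dict[str, str]:
    env_vars: Dict[str, str] = {}
    # Hosts whose instance group is not distributing (e.g. a data-serving CPU
    # group in a heterogeneous cluster) skip the TF_CONFIG setup entirely.
    if env.current_instance_group in env.distribution_instance_groups and mwms_enabled:
        env_vars["TF_CONFIG"] = json.dumps(
            build_tf_config_for_mwms(env.distribution_hosts, env.current_host)
        )
    return env_vars

# Example: three hosts, but only the "train" group distributes.
env = Env(
    current_host="algo-1",
    current_instance_group="train",
    hosts=["algo-1", "algo-2", "algo-3"],
    distribution_hosts=["algo-1", "algo-2"],   # "algo-3" is in the data group
    distribution_instance_groups=["train"],
)
print(setup_env_vars(env, mwms_enabled=True)["TF_CONFIG"])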
test/unit/test_training.py (5 changes: 5 additions & 0 deletions)

@@ -52,6 +52,9 @@ def distributed_training_env():
     env = simple_training_env()
 
     env.hosts = HOST_LIST
+    env.current_instance_group = "test1"
+    env.distribution_hosts = ["host1", "host2"]
+    env.distribution_instance_groups = ["test1"]
     env.additional_framework_parameters = {training.SAGEMAKER_PARAMETER_SERVER_ENABLED: True}
     return env

@@ -63,6 +66,8 @@ def single_machine_training_env():
 
 def simple_training_env():
     env = MagicMock()
+    env.current_instance_group = "test1"
+    env.distribution_instance_groups = ["test1"]
     env.module_dir = MODULE_DIR
     env.user_entry_point = MODULE_NAME
     env.hyperparameters = {"model_dir": MODEL_DIR}
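The fixtures now model a host whose instance group ("test1") is in distribution_instance_groups, so the existing tests keep exercising the distributed code paths through the new gate. A hedged sketch of a complementary test one could add to this module (not part of the commit; the patch target and the keyword-argument form of entry_point.run are assumed from the existing tests' style):

from unittest.mock import patch

from sagemaker_training import runner
from sagemaker_tensorflow_container import training

@patch("sagemaker_training.entry_point.run")
def test_non_distribution_group_falls_back_to_process_runner(run):
    env = simple_training_env()
    env.current_instance_group = "test2"      # not in distribution_instance_groups
    env.additional_framework_parameters = {}  # no PS / MPI / SMDDP requested
    training.train(env, cmd_args=[])
    # With the gate skipped, train() should keep the ProcessRunnerType default.
    assert run.call_args[1]["runner_type"] == runner.ProcessRunnerType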
