Skip to content

Commit

Permalink
Merge branch 'aws:master' into rsareddy-dev
Browse files Browse the repository at this point in the history
  • Loading branch information
rsareddy0329 authored Jan 30, 2025
2 parents 9321367 + c753da0 commit f972222
Show file tree
Hide file tree
Showing 2 changed files with 153 additions and 14 deletions.
54 changes: 40 additions & 14 deletions src/sagemaker/modules/train/container_drivers/mpi_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
from __future__ import absolute_import

import os
import time
import subprocess

import time
from typing import List

from utils import logger, SM_EFA_NCCL_INSTANCES, SM_EFA_RDMA_INSTANCES, get_python_executable
import paramiko
from utils import SM_EFA_NCCL_INSTANCES, SM_EFA_RDMA_INSTANCES, get_python_executable, logger

FINISHED_STATUS_FILE = "/tmp/done.algo-1"
READY_FILE = "/tmp/ready.%s"
Expand Down Expand Up @@ -75,19 +75,45 @@ def start_sshd_daemon():
logger.info("Started SSH daemon.")


class CustomHostKeyPolicy(paramiko.client.MissingHostKeyPolicy):
"""Class to handle host key policy for SageMaker distributed training SSH connections.
Example:
>>> client = paramiko.SSHClient()
>>> client.set_missing_host_key_policy(CustomHostKeyPolicy())
>>> # Will succeed for SageMaker algorithm containers
>>> client.connect('algo-1234.internal')
>>> # Will raise SSHException for other unknown hosts
>>> client.connect('unknown-host') # raises SSHException
"""

def missing_host_key(self, client, hostname, key):
"""Accept host keys for algo-* hostnames, reject others.
Args:
client: The SSHClient instance
hostname: The hostname attempting to connect
key: The host key
Raises:
paramiko.SSHException: If hostname doesn't match algo-* pattern
"""
if hostname.startswith("algo-"):
client.get_host_keys().add(hostname, key.get_name(), key)
return
raise paramiko.SSHException(f"Unknown host key for {hostname}")


def _can_connect(host: str, port: int = DEFAULT_SSH_PORT) -> bool:
"""Check if the connection to the provided host and port is possible."""
try:
import paramiko

logger.debug("Testing connection to host %s", host)
client = paramiko.SSHClient()
client.load_system_host_keys()
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
client.connect(host, port=port)
client.close()
logger.info("Can connect to host %s", host)
return True
with paramiko.SSHClient() as client:
client.load_system_host_keys()
client.set_missing_host_key_policy(CustomHostKeyPolicy())
client.connect(host, port=port)
logger.info("Can connect to host %s", host)
return True
except Exception as e: # pylint: disable=W0703
logger.info("Cannot connect to host %s", host)
logger.debug(f"Connection failed with exception: {e}")
Expand Down Expand Up @@ -183,9 +209,9 @@ def validate_smddpmprun() -> bool:

def write_env_vars_to_file():
"""Write environment variables to /etc/environment file."""
with open("/etc/environment", "a") as f:
with open("/etc/environment", "a", encoding="utf-8") as f:
for name in os.environ:
f.write("{}={}\n".format(name, os.environ.get(name)))
f.write(f"{name}={os.environ.get(name)}\n")


def get_mpirun_command(
Expand Down
113 changes: 113 additions & 0 deletions tests/unit/sagemaker/modules/train/container_drivers/test_mpi_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""MPI Utils Unit Tests."""
from __future__ import absolute_import

import subprocess
from unittest.mock import Mock, patch

import paramiko
import pytest

# Mock the utils module before importing mpi_utils
mock_utils = Mock()
mock_utils.logger = Mock()
mock_utils.SM_EFA_NCCL_INSTANCES = []
mock_utils.SM_EFA_RDMA_INSTANCES = []
mock_utils.get_python_executable = Mock(return_value="/usr/bin/python")

with patch.dict("sys.modules", {"utils": mock_utils}):
from sagemaker.modules.train.container_drivers.mpi_utils import (
CustomHostKeyPolicy,
_can_connect,
write_status_file_to_workers,
)

TEST_HOST = "algo-1"
TEST_WORKER = "algo-2"
TEST_STATUS_FILE = "/tmp/test-status"


def test_custom_host_key_policy_valid_hostname():
"""Test CustomHostKeyPolicy accepts algo- prefixed hostnames."""
policy = CustomHostKeyPolicy()
mock_client = Mock()
mock_key = Mock()
mock_key.get_name.return_value = "ssh-rsa"

policy.missing_host_key(mock_client, "algo-1", mock_key)

mock_client.get_host_keys.assert_called_once()
mock_client.get_host_keys().add.assert_called_once_with("algo-1", "ssh-rsa", mock_key)


def test_custom_host_key_policy_invalid_hostname():
"""Test CustomHostKeyPolicy rejects non-algo prefixed hostnames."""
policy = CustomHostKeyPolicy()
mock_client = Mock()
mock_key = Mock()

with pytest.raises(paramiko.SSHException) as exc_info:
policy.missing_host_key(mock_client, "invalid-1", mock_key)

assert "Unknown host key for invalid-1" in str(exc_info.value)
mock_client.get_host_keys.assert_not_called()


@patch("paramiko.SSHClient")
@patch("sagemaker.modules.train.container_drivers.mpi_utils.logger")
def test_can_connect_success(mock_logger, mock_ssh_client):
"""Test successful SSH connection."""
mock_client = Mock()
mock_ssh_client.return_value.__enter__.return_value = mock_client
mock_client.connect.return_value = None # Successful connection

result = _can_connect(TEST_HOST)

assert result is True
mock_client.load_system_host_keys.assert_called_once()
mock_client.set_missing_host_key_policy.assert_called_once()
mock_client.connect.assert_called_once_with(TEST_HOST, port=22)


@patch("paramiko.SSHClient")
@patch("sagemaker.modules.train.container_drivers.mpi_utils.logger")
def test_can_connect_failure(mock_logger, mock_ssh_client):
"""Test SSH connection failure."""
mock_client = Mock()
mock_ssh_client.return_value.__enter__.return_value = mock_client
mock_client.connect.side_effect = paramiko.SSHException("Connection failed")

result = _can_connect(TEST_HOST)

assert result is False
mock_client.load_system_host_keys.assert_called_once()
mock_client.set_missing_host_key_policy.assert_called_once()
mock_client.connect.assert_called_once_with(TEST_HOST, port=22)


@patch("subprocess.run")
@patch("sagemaker.modules.train.container_drivers.mpi_utils.logger")
def test_write_status_file_to_workers_failure(mock_logger, mock_run):
"""Test failed status file writing to workers with retry timeout."""
mock_run.side_effect = subprocess.CalledProcessError(1, "ssh")

with pytest.raises(TimeoutError) as exc_info:
write_status_file_to_workers([TEST_WORKER], TEST_STATUS_FILE)

assert f"Timed out waiting for {TEST_WORKER}" in str(exc_info.value)
assert mock_run.call_count > 1 # Verifies that retries occurred


if __name__ == "__main__":
pytest.main([__file__])

0 comments on commit f972222

Please sign in to comment.