Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix ssh host policy #4966

Merged
merged 22 commits into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 40 additions & 14 deletions src/sagemaker/modules/train/container_drivers/mpi_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
from __future__ import absolute_import

import os
import time
import subprocess

import time
from typing import List

from utils import logger, SM_EFA_NCCL_INSTANCES, SM_EFA_RDMA_INSTANCES, get_python_executable
import paramiko
from utils import SM_EFA_NCCL_INSTANCES, SM_EFA_RDMA_INSTANCES, get_python_executable, logger

FINISHED_STATUS_FILE = "/tmp/done.algo-1"
READY_FILE = "/tmp/ready.%s"
Expand Down Expand Up @@ -75,19 +75,45 @@
logger.info("Started SSH daemon.")


class CustomHostKeyPolicy(paramiko.client.MissingHostKeyPolicy):
"""Class to handle host key policy for SageMaker distributed training SSH connections.

Example:
>>> client = paramiko.SSHClient()
>>> client.set_missing_host_key_policy(CustomHostKeyPolicy())
>>> # Will succeed for SageMaker algorithm containers
>>> client.connect('algo-1234.internal')
>>> # Will raise SSHException for other unknown hosts
>>> client.connect('unknown-host') # raises SSHException
"""

def missing_host_key(self, client, hostname, key):
"""Accept host keys for algo-* hostnames, reject others.

Args:
client: The SSHClient instance
hostname: The hostname attempting to connect
key: The host key

Raises:
paramiko.SSHException: If hostname doesn't match algo-* pattern
"""
if hostname.startswith("algo-"):
client.get_host_keys().add(hostname, key.get_name(), key)
return
raise paramiko.SSHException(f"Unknown host key for {hostname}")


def _can_connect(host: str, port: int = DEFAULT_SSH_PORT) -> bool:
"""Check if the connection to the provided host and port is possible."""
try:
import paramiko

logger.debug("Testing connection to host %s", host)
client = paramiko.SSHClient()
client.load_system_host_keys()
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
client.connect(host, port=port)
client.close()
logger.info("Can connect to host %s", host)
return True
with paramiko.SSHClient() as client:
client.load_system_host_keys()
client.set_missing_host_key_policy(CustomHostKeyPolicy())
client.connect(host, port=port)
logger.info("Can connect to host %s", host)
return True
except Exception as e: # pylint: disable=W0703
logger.info("Cannot connect to host %s", host)
logger.debug(f"Connection failed with exception: {e}")
Expand Down Expand Up @@ -183,9 +209,9 @@

def write_env_vars_to_file():
"""Write environment variables to /etc/environment file."""
with open("/etc/environment", "a") as f:
with open("/etc/environment", "a", encoding="utf-8") as f:

Check warning on line 212 in src/sagemaker/modules/train/container_drivers/mpi_utils.py

View check run for this annotation

Codecov / codecov/patch

src/sagemaker/modules/train/container_drivers/mpi_utils.py#L212

Added line #L212 was not covered by tests
for name in os.environ:
f.write("{}={}\n".format(name, os.environ.get(name)))
f.write(f"{name}={os.environ.get(name)}\n")

Check warning on line 214 in src/sagemaker/modules/train/container_drivers/mpi_utils.py

View check run for this annotation

Codecov / codecov/patch

src/sagemaker/modules/train/container_drivers/mpi_utils.py#L214

Added line #L214 was not covered by tests


def get_mpirun_command(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""MPI Utils Unit Tests."""
from __future__ import absolute_import

import subprocess
from unittest.mock import Mock, patch

import paramiko
import pytest

# Mock the utils module before importing mpi_utils
mock_utils = Mock()
mock_utils.logger = Mock()
mock_utils.SM_EFA_NCCL_INSTANCES = []
mock_utils.SM_EFA_RDMA_INSTANCES = []
mock_utils.get_python_executable = Mock(return_value="/usr/bin/python")

with patch.dict("sys.modules", {"utils": mock_utils}):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice

from sagemaker.modules.train.container_drivers.mpi_utils import (
CustomHostKeyPolicy,
_can_connect,
write_status_file_to_workers,
)

TEST_HOST = "algo-1"
TEST_WORKER = "algo-2"
TEST_STATUS_FILE = "/tmp/test-status"


def test_custom_host_key_policy_valid_hostname():
"""Test CustomHostKeyPolicy accepts algo- prefixed hostnames."""
policy = CustomHostKeyPolicy()
mock_client = Mock()
mock_key = Mock()
mock_key.get_name.return_value = "ssh-rsa"

policy.missing_host_key(mock_client, "algo-1", mock_key)

mock_client.get_host_keys.assert_called_once()
mock_client.get_host_keys().add.assert_called_once_with("algo-1", "ssh-rsa", mock_key)


def test_custom_host_key_policy_invalid_hostname():
"""Test CustomHostKeyPolicy rejects non-algo prefixed hostnames."""
policy = CustomHostKeyPolicy()
mock_client = Mock()
mock_key = Mock()

with pytest.raises(paramiko.SSHException) as exc_info:
policy.missing_host_key(mock_client, "invalid-1", mock_key)

assert "Unknown host key for invalid-1" in str(exc_info.value)
mock_client.get_host_keys.assert_not_called()


@patch("paramiko.SSHClient")
@patch("sagemaker.modules.train.container_drivers.mpi_utils.logger")
def test_can_connect_success(mock_logger, mock_ssh_client):
"""Test successful SSH connection."""
mock_client = Mock()
mock_ssh_client.return_value.__enter__.return_value = mock_client
mock_client.connect.return_value = None # Successful connection

result = _can_connect(TEST_HOST)

assert result is True
mock_client.load_system_host_keys.assert_called_once()
mock_client.set_missing_host_key_policy.assert_called_once()
mock_client.connect.assert_called_once_with(TEST_HOST, port=22)
mock_logger.info.assert_called_with("Can connect to host %s", TEST_HOST)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this assert has flaky results, the previous asserts should be sufficient



@patch("paramiko.SSHClient")
@patch("sagemaker.modules.train.container_drivers.mpi_utils.logger")
def test_can_connect_failure(mock_logger, mock_ssh_client):
"""Test SSH connection failure."""
mock_client = Mock()
mock_ssh_client.return_value.__enter__.return_value = mock_client
mock_client.connect.side_effect = paramiko.SSHException("Connection failed")

result = _can_connect(TEST_HOST)

assert result is False
mock_client.load_system_host_keys.assert_called_once()
mock_client.set_missing_host_key_policy.assert_called_once()
mock_client.connect.assert_called_once_with(TEST_HOST, port=22)
mock_logger.info.assert_called_with("Cannot connect to host %s", TEST_HOST)


@patch("subprocess.run")
@patch("sagemaker.modules.train.container_drivers.mpi_utils.logger")
def test_write_status_file_to_workers_failure(mock_logger, mock_run):
"""Test failed status file writing to workers with retry timeout."""
mock_run.side_effect = subprocess.CalledProcessError(1, "ssh")

with pytest.raises(TimeoutError) as exc_info:
write_status_file_to_workers([TEST_WORKER], TEST_STATUS_FILE)

assert f"Timed out waiting for {TEST_WORKER}" in str(exc_info.value)
assert mock_run.call_count > 1 # Verifies that retries occurred
mock_logger.info.assert_any_call(f"Cannot connect to {TEST_WORKER}")


if __name__ == "__main__":
pytest.main([__file__])
Loading