Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix ssh host policy #4966

Merged
merged 22 commits into from
Jan 30, 2025
Merged
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions src/sagemaker/modules/train/container_drivers/mpi_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
from __future__ import absolute_import

import os
import time
import subprocess

import time

Check warning on line 18 in src/sagemaker/modules/train/container_drivers/mpi_utils.py

View check run for this annotation

Codecov / codecov/patch

src/sagemaker/modules/train/container_drivers/mpi_utils.py#L18

Added line #L18 was not covered by tests
from typing import List

from utils import logger, SM_EFA_NCCL_INSTANCES, SM_EFA_RDMA_INSTANCES, get_python_executable
import paramiko
from utils import SM_EFA_NCCL_INSTANCES, SM_EFA_RDMA_INSTANCES, get_python_executable, logger

Check warning on line 22 in src/sagemaker/modules/train/container_drivers/mpi_utils.py

View check run for this annotation

Codecov / codecov/patch

src/sagemaker/modules/train/container_drivers/mpi_utils.py#L21-L22

Added lines #L21 - L22 were not covered by tests

FINISHED_STATUS_FILE = "/tmp/done.algo-1"
READY_FILE = "/tmp/ready.%s"
Expand Down Expand Up @@ -75,6 +75,24 @@
logger.info("Started SSH daemon.")


class CustomHostKeyPolicy(paramiko.client.MissingHostKeyPolicy):
def missing_host_key(self, client, hostname, key):

Check warning on line 79 in src/sagemaker/modules/train/container_drivers/mpi_utils.py

View check run for this annotation

Codecov / codecov/patch

src/sagemaker/modules/train/container_drivers/mpi_utils.py#L78-L79

Added lines #L78 - L79 were not covered by tests
"""Accept host keys for algo-* hostnames, reject others.

Args:
client: The SSHClient instance
hostname: The hostname attempting to connect
key: The host key

Raises:
paramiko.SSHException: If hostname doesn't match algo-* pattern
"""
if hostname.startswith("algo-"):
client.get_host_keys().add(hostname, key.get_name(), key)
return
raise paramiko.SSHException(f"Unknown host key for {hostname}")

Check warning on line 93 in src/sagemaker/modules/train/container_drivers/mpi_utils.py

View check run for this annotation

Codecov / codecov/patch

src/sagemaker/modules/train/container_drivers/mpi_utils.py#L90-L93

Added lines #L90 - L93 were not covered by tests


def _can_connect(host: str, port: int = DEFAULT_SSH_PORT) -> bool:
"""Check if the connection to the provided host and port is possible."""
try:
Expand All @@ -83,7 +101,7 @@
logger.debug("Testing connection to host %s", host)
client = paramiko.SSHClient()
client.load_system_host_keys()
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
client.set_missing_host_key_policy(CustomHostKeyPolicy())

Check warning on line 104 in src/sagemaker/modules/train/container_drivers/mpi_utils.py

View check run for this annotation

Codecov / codecov/patch

src/sagemaker/modules/train/container_drivers/mpi_utils.py#L104

Added line #L104 was not covered by tests
client.connect(host, port=port)
client.close()
logger.info("Can connect to host %s", host)
Expand Down
Loading