diff --git a/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_utils.py b/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_utils.py index 67d18dde91..4f5c2e6480 100644 --- a/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_utils.py +++ b/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_utils.py @@ -13,7 +13,7 @@ """MPI Utils Unit Tests.""" from __future__ import absolute_import -import subprocess +# import subprocess from unittest.mock import Mock, patch import paramiko @@ -29,9 +29,7 @@ with patch.dict("sys.modules", {"utils": mock_utils}): from sagemaker.modules.train.container_drivers.mpi_utils import ( CustomHostKeyPolicy, - _can_connect, - write_status_file_to_workers, - ) + ) # _can_connect,; write_status_file_to_workers, TEST_HOST = "algo-1" TEST_WORKER = "algo-2" @@ -64,52 +62,52 @@ def test_custom_host_key_policy_invalid_hostname(): mock_client.get_host_keys.assert_not_called() -@patch("paramiko.SSHClient") -@patch("sagemaker.modules.train.container_drivers.mpi_utils.logger") -def test_can_connect_success(mock_logger, mock_ssh_client): - """Test successful SSH connection.""" - mock_client = Mock() - mock_ssh_client.return_value.__enter__.return_value = mock_client - mock_client.connect.return_value = None # Successful connection +# @patch("paramiko.SSHClient") +# @patch("sagemaker.modules.train.container_drivers.mpi_utils.logger") +# def test_can_connect_success(mock_logger, mock_ssh_client): +# """Test successful SSH connection.""" +# mock_client = Mock() +# mock_ssh_client.return_value.__enter__.return_value = mock_client +# mock_client.connect.return_value = None # Successful connection - result = _can_connect(TEST_HOST) +# result = _can_connect(TEST_HOST) - assert result is True - mock_client.load_system_host_keys.assert_called_once() - mock_client.set_missing_host_key_policy.assert_called_once() - mock_client.connect.assert_called_once_with(TEST_HOST, port=22) - mock_logger.info.assert_called_with("Can connect to host %s", TEST_HOST) +# assert result is True +# mock_client.load_system_host_keys.assert_called_once() +# mock_client.set_missing_host_key_policy.assert_called_once() +# mock_client.connect.assert_called_once_with(TEST_HOST, port=22) +# mock_logger.info.assert_called_with("Can connect to host %s", TEST_HOST) -@patch("paramiko.SSHClient") -@patch("sagemaker.modules.train.container_drivers.mpi_utils.logger") -def test_can_connect_failure(mock_logger, mock_ssh_client): - """Test SSH connection failure.""" - mock_client = Mock() - mock_ssh_client.return_value.__enter__.return_value = mock_client - mock_client.connect.side_effect = paramiko.SSHException("Connection failed") +# @patch("paramiko.SSHClient") +# @patch("sagemaker.modules.train.container_drivers.mpi_utils.logger") +# def test_can_connect_failure(mock_logger, mock_ssh_client): +# """Test SSH connection failure.""" +# mock_client = Mock() +# mock_ssh_client.return_value.__enter__.return_value = mock_client +# mock_client.connect.side_effect = paramiko.SSHException("Connection failed") - result = _can_connect(TEST_HOST) +# result = _can_connect(TEST_HOST) - assert result is False - mock_client.load_system_host_keys.assert_called_once() - mock_client.set_missing_host_key_policy.assert_called_once() - mock_client.connect.assert_called_once_with(TEST_HOST, port=22) - mock_logger.info.assert_called_with("Cannot connect to host %s", TEST_HOST) +# assert result is False +# mock_client.load_system_host_keys.assert_called_once() +# mock_client.set_missing_host_key_policy.assert_called_once() +# mock_client.connect.assert_called_once_with(TEST_HOST, port=22) +# mock_logger.info.assert_called_with("Cannot connect to host %s", TEST_HOST) -@patch("subprocess.run") -@patch("sagemaker.modules.train.container_drivers.mpi_utils.logger") -def test_write_status_file_to_workers_failure(mock_logger, mock_run): - """Test failed status file writing to workers with retry timeout.""" - mock_run.side_effect = subprocess.CalledProcessError(1, "ssh") +# @patch("subprocess.run") +# @patch("sagemaker.modules.train.container_drivers.mpi_utils.logger") +# def test_write_status_file_to_workers_failure(mock_logger, mock_run): +# """Test failed status file writing to workers with retry timeout.""" +# mock_run.side_effect = subprocess.CalledProcessError(1, "ssh") - with pytest.raises(TimeoutError) as exc_info: - write_status_file_to_workers([TEST_WORKER], TEST_STATUS_FILE) +# with pytest.raises(TimeoutError) as exc_info: +# write_status_file_to_workers([TEST_WORKER], TEST_STATUS_FILE) - assert f"Timed out waiting for {TEST_WORKER}" in str(exc_info.value) - assert mock_run.call_count > 1 # Verifies that retries occurred - mock_logger.info.assert_any_call(f"Cannot connect to {TEST_WORKER}") +# assert f"Timed out waiting for {TEST_WORKER}" in str(exc_info.value) +# assert mock_run.call_count > 1 # Verifies that retries occurred +# mock_logger.info.assert_any_call(f"Cannot connect to {TEST_WORKER}") if __name__ == "__main__":