Skip to content

Commit

Permalink
Fixed hyperopt trial syncing to remote filesystems for Ray 2.0 (#2617)
Browse files Browse the repository at this point in the history
  • Loading branch information
tgaddair authored Oct 11, 2022
1 parent dad9171 commit d8a0d8f
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 58 deletions.
14 changes: 8 additions & 6 deletions ludwig/hyperopt/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,11 @@
if _ray_200:
from ray.air import Checkpoint
from ray.tune.search import SEARCH_ALG_IMPORT
from ray.tune.syncer import get_node_to_storage_syncer, SyncConfig

from ludwig.hyperopt.syncer import RemoteSyncer
else:
from ray.ml import Checkpoint
from ray.tune.suggest import SEARCH_ALG_IMPORT
from ray.tune.syncer import get_cloud_sync_client


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -778,12 +778,14 @@ def run_experiment_trial(config, local_hyperopt_dict, checkpoint_dir=None):
)

if has_remote_protocol(output_directory):
run_experiment_trial = tune.durable(run_experiment_trial)
self.sync_config = tune.SyncConfig(sync_to_driver=False, upload_dir=output_directory)
if _ray_200:
self.sync_client = get_node_to_storage_syncer(SyncConfig(upload_dir=output_directory))
self.sync_client = RemoteSyncer()
self.sync_config = tune.SyncConfig(upload_dir=output_directory, syncer=self.sync_client)
else:
self.sync_client = get_cloud_sync_client(output_directory)
raise ValueError(
"Syncing to remote filesystems with hyperopt is not supported with ray<2.0, "
"please upgrade to ray>=2.0"
)
output_directory = None
elif self.kubernetes_namespace:
from ray.tune.integration.kubernetes import KubernetesSyncClient, NamespacedKubernetesSyncer
Expand Down
34 changes: 34 additions & 0 deletions ludwig/hyperopt/syncer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from typing import Any, Callable, Dict, List, Optional, Tuple

from ray.tune.syncer import _BackgroundSyncer

from ludwig.utils.data_utils import use_credentials
from ludwig.utils.fs_utils import delete, download, upload


class RemoteSyncer(_BackgroundSyncer):
def __init__(self, sync_period: float = 300.0, creds: Optional[Dict[str, Any]] = None):
super().__init__(sync_period=sync_period)
self.creds = creds

def _sync_up_command(self, local_path: str, uri: str, exclude: Optional[List] = None) -> Tuple[Callable, Dict]:
with use_credentials(self.creds):
return upload, dict(lpath=local_path, rpath=uri)

def _sync_down_command(self, uri: str, local_path: str) -> Tuple[Callable, Dict]:
with use_credentials(self.creds):
return download, dict(rpath=uri, lpath=local_path)

def _delete_command(self, uri: str) -> Tuple[Callable, Dict]:
with use_credentials(self.creds):
return delete, dict(url=uri, recursive=True)

def __reduce__(self):
"""We need this custom serialization because we can't pickle thread.lock objects that are used by the
use_credentials context manager.
https://docs.ray.io/en/latest/ray-core/objects/serialization.html#customized-serialization
"""
deserializer = RemoteSyncer
serialized_data = (self.sync_period, self.creds)
return deserializer, serialized_data
10 changes: 10 additions & 0 deletions ludwig/utils/fs_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,16 @@ def delete(url, recursive=False):
return fs.delete(path, recursive=recursive)


def upload(lpath, rpath):
fs, path = get_fs_and_path(rpath)
pyarrow.fs.copy_files(lpath, path, destination_filesystem=pyarrow.fs.PyFileSystem(pyarrow.fs.FSSpecHandler(fs)))


def download(rpath, lpath):
fs, path = get_fs_and_path(rpath)
pyarrow.fs.copy_files(path, lpath, source_filesystem=pyarrow.fs.PyFileSystem(pyarrow.fs.FSSpecHandler(fs)))


def checksum(url):
fs, path = get_fs_and_path(url)
return fs.checksum(path)
Expand Down
62 changes: 29 additions & 33 deletions tests/integration_tests/test_hyperopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
import contextlib
import json
import os.path
from typing import Any, Dict, Optional, Tuple, Union
import uuid
from typing import Any, Dict, Optional, Tuple

import pytest
import torch
Expand All @@ -39,23 +40,21 @@
from ludwig.globals import HYPEROPT_STATISTICS_FILE_NAME
from ludwig.hyperopt.results import HyperoptResults
from ludwig.hyperopt.run import hyperopt, update_hyperopt_params_with_defaults
from ludwig.utils import fs_utils
from ludwig.utils.data_utils import load_json
from ludwig.utils.defaults import merge_with_defaults
from tests.integration_tests.utils import category_feature, generate_data, text_feature
from tests.integration_tests.utils import category_feature, generate_data, private_param, remote_tmpdir, text_feature

try:
import ray
ray = pytest.importorskip("ray")

from ludwig.hyperopt.execution import get_build_hyperopt_executor
from ludwig.hyperopt.execution import get_build_hyperopt_executor # noqa

_ray113 = version.parse(ray.__version__) > version.parse("1.13")
_ray200 = version.parse(ray.__version__) >= version.parse("2.0")

except ImportError:
ray = None
_ray113 = None
pytestmark = pytest.mark.distributed


RANDOM_SEARCH_SIZE = 4
RANDOM_SEARCH_SIZE = 2

HYPEROPT_CONFIG = {
"parameters": {
Expand Down Expand Up @@ -165,18 +164,6 @@ def _setup_ludwig_config_with_shared_params(dataset_fp: str) -> Tuple[Dict, Any]
return config, rel_path, num_filters_search_space, embedding_size_search_space, reduce_input_search_space


def _get_trial_parameter_value(parameter_key: str, trial_row: str) -> Union[str, None]:
"""Returns the parameter value from the Ray trial row, which has slightly different column names depending on
the version of Ray. Returns None if the parameter key is not found.
TODO(#2176): There are different key name delimiters depending on Ray version. The delimiter in future versions of
Ray (> 1.13) will be '/' instead of '.' Simplify this as Ray is upgraded.
"""
if _ray113:
return trial_row[f"config/{parameter_key}"]
return trial_row[f"config.{parameter_key}"]


@contextlib.contextmanager
def ray_start(num_cpus: Optional[int] = None, num_gpus: Optional[int] = None):
res = ray.init(
Expand All @@ -198,7 +185,6 @@ def ray_cluster():
yield


@pytest.mark.distributed
@pytest.mark.parametrize("search_alg", SEARCH_ALGS_FOR_TESTING)
def test_hyperopt_search_alg(
search_alg, csv_filename, tmpdir, ray_cluster, validate_output_feature=False, validation_metric=None
Expand Down Expand Up @@ -249,7 +235,6 @@ def test_hyperopt_search_alg(
assert isinstance(path, str)


@pytest.mark.distributed
def test_hyperopt_executor_with_metric(csv_filename, tmpdir, ray_cluster):
test_hyperopt_search_alg(
"variant_generator",
Expand All @@ -261,7 +246,6 @@ def test_hyperopt_executor_with_metric(csv_filename, tmpdir, ray_cluster):
)


@pytest.mark.distributed
@pytest.mark.parametrize("scheduler", SCHEDULERS_FOR_TESTING)
def test_hyperopt_scheduler(
scheduler, csv_filename, tmpdir, ray_cluster, validate_output_feature=False, validation_metric=None
Expand Down Expand Up @@ -316,7 +300,6 @@ def test_hyperopt_scheduler(
assert isinstance(raytune_results, HyperoptResults)


@pytest.mark.distributed
@pytest.mark.parametrize("search_space", ["random", "grid"])
def test_hyperopt_run_hyperopt(csv_filename, search_space, tmpdir, ray_cluster):
input_features = [
Expand Down Expand Up @@ -370,14 +353,19 @@ def test_hyperopt_run_hyperopt(csv_filename, search_space, tmpdir, ray_cluster):
"goal": "minimize",
"output_feature": output_feature_name,
"validation_metrics": "loss",
"executor": {TYPE: "ray", "num_samples": 1 if search_space == "grid" else RANDOM_SEARCH_SIZE},
"executor": {
TYPE: "ray",
"num_samples": 1 if search_space == "grid" else RANDOM_SEARCH_SIZE,
"max_concurrent_trials": 1,
},
"search_alg": {TYPE: "variant_generator"},
}

# add hyperopt parameter space to the config
config[HYPEROPT] = hyperopt_configs

hyperopt_results = hyperopt(config, dataset=rel_path, output_directory=tmpdir, experiment_name="test_hyperopt")
experiment_name = f"test_hyperopt_{uuid.uuid4().hex}"
hyperopt_results = hyperopt(config, dataset=rel_path, output_directory=tmpdir, experiment_name=experiment_name)
if search_space == "random":
assert hyperopt_results.experiment_analysis.results_df.shape[0] == RANDOM_SEARCH_SIZE
else:
Expand All @@ -391,10 +379,21 @@ def test_hyperopt_run_hyperopt(csv_filename, search_space, tmpdir, ray_cluster):
assert isinstance(hyperopt_results, HyperoptResults)

# check for existence of the hyperopt statistics file
assert os.path.isfile(os.path.join(tmpdir, "test_hyperopt", HYPEROPT_STATISTICS_FILE_NAME))
assert fs_utils.path_exists(os.path.join(tmpdir, experiment_name, HYPEROPT_STATISTICS_FILE_NAME))


@pytest.mark.parametrize("fs_protocol,bucket", [private_param(("s3", "ludwig-tests"))], ids=["s3"])
def test_hyperopt_sync_remote(fs_protocol, bucket, csv_filename, ray_cluster):
with remote_tmpdir(fs_protocol, bucket) as tmpdir:
with pytest.raises(ValueError) if not _ray200 else contextlib.nullcontext():
test_hyperopt_run_hyperopt(
csv_filename,
"random",
tmpdir,
ray_cluster,
)


@pytest.mark.distributed
def test_hyperopt_with_feature_specific_parameters(csv_filename, tmpdir, ray_cluster):
input_features = [
text_feature(name="utterance", reduce_output="sum"),
Expand Down Expand Up @@ -446,7 +445,6 @@ def test_hyperopt_with_feature_specific_parameters(csv_filename, tmpdir, ray_clu
assert input_feature["encoder"]["embedding_size"] in embedding_size_search_space


@pytest.mark.distributed
def test_hyperopt_old_config(csv_filename, tmpdir, ray_cluster):
old_config = {
"ludwig_version": "0.4",
Expand Down Expand Up @@ -500,7 +498,6 @@ def test_hyperopt_old_config(csv_filename, tmpdir, ray_cluster):
hyperopt(old_config, dataset=rel_path, output_directory=tmpdir, experiment_name="test_hyperopt")


@pytest.mark.distributed
def test_hyperopt_nested_parameters(csv_filename, tmpdir, ray_cluster):
config = {
INPUT_FEATURES: [
Expand Down Expand Up @@ -591,7 +588,6 @@ def test_hyperopt_nested_parameters(csv_filename, tmpdir, ray_cluster):
assert trial_config[TRAINER]["learning_rate"] in {0.7, 0.42}


@pytest.mark.distributed
def test_hyperopt_grid_search_more_than_one_sample(csv_filename, tmpdir, ray_cluster):
input_features = [
text_feature(name="utterance", encoder={"reduce_output": "sum"}),
Expand Down
26 changes: 7 additions & 19 deletions tests/integration_tests/test_remote.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
import contextlib
import os
import tempfile
import uuid

import pytest
import yaml
Expand All @@ -11,22 +8,13 @@
from ludwig.constants import TRAINER
from ludwig.globals import DESCRIPTION_FILE_NAME
from ludwig.utils import fs_utils
from tests.integration_tests.utils import category_feature, generate_data, private_param, sequence_feature


@contextlib.contextmanager
def remote_tmpdir(fs_protocol, bucket):
if bucket is None:
with tempfile.TemporaryDirectory() as tmpdir:
yield f"{fs_protocol}://{tmpdir}"
return

prefix = f"tmp_{uuid.uuid4().hex}"
tmpdir = f"{fs_protocol}://{bucket}/{prefix}"
try:
yield tmpdir
finally:
fs_utils.delete(tmpdir, recursive=True)
from tests.integration_tests.utils import (
category_feature,
generate_data,
private_param,
remote_tmpdir,
sequence_feature,
)


@pytest.mark.parametrize(
Expand Down
21 changes: 21 additions & 0 deletions tests/integration_tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================

import contextlib
import logging
import multiprocessing
import os
Expand Down Expand Up @@ -40,6 +41,7 @@
from ludwig.experiment import experiment_cli
from ludwig.features.feature_utils import compute_feature_hash
from ludwig.trainers.trainer import Trainer
from ludwig.utils import fs_utils
from ludwig.utils.data_utils import read_csv, replace_file_extension

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -872,3 +874,22 @@ def filter(stats):
finally:
# Remove results/intermediate data saved to disk
shutil.rmtree(output_dir, ignore_errors=True)


@contextlib.contextmanager
def remote_tmpdir(fs_protocol, bucket):
if bucket is None:
with tempfile.TemporaryDirectory() as tmpdir:
yield f"{fs_protocol}://{tmpdir}"
return

prefix = f"tmp_{uuid.uuid4().hex}"
tmpdir = f"{fs_protocol}://{bucket}/{prefix}"
try:
yield tmpdir
finally:
try:
fs_utils.delete(tmpdir, recursive=True)
except FileNotFoundError as e:
logging.info(f"failed to delete remote tempdir, does not exist: {str(e)}")
pass

0 comments on commit d8a0d8f

Please sign in to comment.