Enable saving hyperopt checkpoints with multi-node clusters #2386

Closed
5 changes: 3 additions & 2 deletions ludwig/api.py
@@ -369,7 +369,7 @@ def train(

if model_resume_path is None:
if self.backend.is_coordinator():
output_directory = get_output_directory(output_directory, experiment_name, model_name)
output_directory = get_output_directory(output_directory, experiment_name, model_name, self.backend)
else:
output_directory = None

@@ -551,6 +551,7 @@ def on_epoch_end(self, trainer, progress_tracker, save_path):
model=self.model,
config=self.config,
config_fp=self.config_fp,
save_path=model_dir,
)

try:
@@ -1424,7 +1425,7 @@ def load_weights(
self.backend.sync_model(self.model)

def save(self, save_path: str) -> None:
"""This function allows to save models on disk.
"""This function allows saving models on disk.

# Inputs

37 changes: 35 additions & 2 deletions ludwig/automl/automl.py
@@ -71,13 +71,37 @@ def __init__(self, experiment_analysis: ExperimentAnalysis):
def experiment_analysis(self):
return self._experiment_analysis

@property
def best_trial(self):
return self._experiment_analysis.best_trial

@property
def best_trial_id(self) -> str:
return self._experiment_analysis.best_trial.trial_id

@property
def best_checkpoint(self) -> str:
return self._experiment_analysis.best_checkpoint

@property
def best_checkpoint_local_path(self) -> str:
checkpoint = self.best_checkpoint
if checkpoint is None:
logging.warning("No best checkpoint found")
return None
return self._experiment_analysis.best_checkpoint.local_path

@property
def best_checkpoint_cloud_path(self) -> str:
checkpoint = self.best_checkpoint
if checkpoint is None:
logging.warning("No best checkpoint found")
return None
return self._experiment_analysis.best_checkpoint.cloud_path

@property
def best_model(self) -> Optional[LudwigModel]:
checkpoint = self._experiment_analysis.best_checkpoint
checkpoint = self.best_checkpoint
if checkpoint is None:
logger.warning("No best model found")
return None
@@ -206,6 +230,7 @@ def create_auto_config(
def train_with_config(
dataset: Union[str, pd.DataFrame, dd.core.DataFrame],
config: dict,
backend: Optional[Backend] = None,
output_directory: str = OUTPUT_DIR,
random_seed: int = default_random_seed,
**kwargs,
@@ -231,7 +256,13 @@ def train_with_config(

model_type = get_model_type(config)
hyperopt_results = _train(
config, dataset, output_directory=output_directory, model_name=model_type, random_seed=random_seed, **kwargs
config,
dataset,
backend=backend,
output_directory=output_directory,
model_name=model_type,
random_seed=random_seed,
**kwargs,
)
# catch edge case where metric_score is nan
# TODO (ASN): Decide how we want to proceed if at least one trial has
@@ -325,10 +356,12 @@ def _train(
output_directory: str,
model_name: str,
random_seed: int,
backend: Optional[Backend] = None,
**kwargs,
):
hyperopt_results = hyperopt(
config,
backend=backend,
dataset=dataset,
output_directory=output_directory,
model_name=model_name,
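A hedged usage sketch of the extended automl API: `train_with_config` now accepts an explicit `backend`, and the results object exposes the best trial's checkpoint paths added above. The import path, bucket names, and the `RayBackend` construction are illustrative assumptions, not part of this diff.

```python
import pandas as pd

from ludwig.automl import train_with_config  # assumed import path
from ludwig.backend.ray import RayBackend

# Backend configured with the new hyperopt sync argument; see the RayBackend sketch
# under ludwig/backend/ray.py below.
backend = RayBackend(hyperopt_sync_dir="s3://my-bucket/hyperopt-checkpoints")

config = {...}  # a full Ludwig config dict, e.g. produced by create_auto_config()

results = train_with_config(
    pd.read_csv("train.csv"),
    config,
    backend=backend,
    output_directory="s3://my-bucket/hyperopt-output",
)

# The results object exposes the new best-checkpoint properties added above.
print(results.best_trial_id)
print(results.best_checkpoint_local_path)  # None (with a warning) if no checkpoint was saved
print(results.best_checkpoint_cloud_path)
best_model = results.best_model  # a LudwigModel, or None
```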
9 changes: 8 additions & 1 deletion ludwig/backend/base.py
@@ -24,7 +24,7 @@
import psutil
import torch

from ludwig.data.cache.manager import CacheManager
from ludwig.data.cache.manager import CacheManager, HyperoptSyncManager
from ludwig.data.dataframe.pandas import PANDAS
from ludwig.data.dataset.base import DatasetManager
from ludwig.data.dataset.pandas import PandasDatasetManager
@@ -43,9 +43,12 @@ def __init__(
dataset_manager: DatasetManager,
cache_dir: Optional[str] = None,
cache_credentials: Optional[Union[str, dict]] = None,
hyperopt_sync_dir: Optional[str] = None,
hyperopt_sync_credentials: Optional[Union[str, dict]] = None,
):
self._dataset_manager = dataset_manager
self._cache_manager = CacheManager(self._dataset_manager, cache_dir, cache_credentials)
self._hyperopt_sync_manager = HyperoptSyncManager(hyperopt_sync_dir, hyperopt_sync_credentials)

@property
def cache(self):
@@ -55,6 +58,10 @@ def cache(self):
def dataset_manager(self):
return self._dataset_manager

@property
def hyperopt_sync_manager(self):
return self._hyperopt_sync_manager

@abstractmethod
def initialize(self):
raise NotImplementedError()
19 changes: 17 additions & 2 deletions ludwig/backend/ray.py
@@ -816,8 +816,23 @@ def _prepare_batch(self, batch: pd.DataFrame) -> Dict[str, np.ndarray]:
class RayBackend(RemoteTrainingMixin, Backend):
BACKEND_TYPE = "ray"

def __init__(self, processor=None, trainer=None, loader=None, use_legacy=False, preprocessor_kwargs=None, **kwargs):
super().__init__(dataset_manager=RayDatasetManager(self), **kwargs)
def __init__(
self,
processor=None,
trainer=None,
loader=None,
use_legacy=False,
preprocessor_kwargs=None,
hyperopt_sync_dir: Optional[str] = None,
hyperopt_sync_credentials: Optional[Union[str, dict]] = None,
**kwargs,
):
super().__init__(
dataset_manager=RayDatasetManager(self),
hyperopt_sync_dir=hyperopt_sync_dir,
hyperopt_sync_credentials=hyperopt_sync_credentials,
**kwargs,
)
self._preprocessor_kwargs = preprocessor_kwargs or {}
self._df_engine = _get_df_engine(processor)
self._horovod_kwargs = trainer or {}
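A short sketch of constructing the extended `RayBackend` directly with the new keyword arguments; in practice these values would typically come from the `backend` section of a Ludwig config, and the bucket and credentials shown here are illustrative only.

```python
from ludwig.backend.ray import RayBackend

# Existing arguments (processor, trainer, loader, ...) are unchanged; only the two
# hyperopt sync arguments are new. Values below are made-up examples.
backend = RayBackend(
    hyperopt_sync_dir="s3://my-bucket/hyperopt-checkpoints",
    hyperopt_sync_credentials={"s3": {"client_kwargs": {"endpoint_url": "http://minio:9000"}}},
)

# The base Backend wraps these in a HyperoptSyncManager (see ludwig/backend/base.py above).
print(backend.hyperopt_sync_manager.sync_dir)
print(backend.hyperopt_sync_manager.credentials)
```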
1 change: 1 addition & 0 deletions ludwig/callbacks.py
@@ -138,6 +138,7 @@ def on_train_start(
model,
config: Dict[str, Any],
config_fp: Union[str, None],
save_path: Union[str, None] = None,
):
"""Called after creation of trainer, before the start of training.

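Since `on_train_start` gains an optional `save_path` (passed as `model_dir` from `api.py` above), a user callback can read it without breaking older implementations. A small assumed sketch; only the `on_train_start` signature comes from this diff.

```python
from typing import Any, Dict, Union

from ludwig.callbacks import Callback


class CheckpointPathLogger(Callback):
    """Illustrative callback; not part of this PR."""

    def on_train_start(
        self,
        model,
        config: Dict[str, Any],
        config_fp: Union[str, None],
        save_path: Union[str, None] = None,
    ):
        # save_path points at the directory where model checkpoints will be written.
        if save_path is not None:
            print(f"Training checkpoints will be saved under: {save_path}")
```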
2 changes: 2 additions & 0 deletions ludwig/constants.py
@@ -230,6 +230,8 @@
S3 = "s3"
CACHE = "cache"

HYPEROPT_LOCAL_DIR = "~/ray_results"

# If `use_torch_profiler=True` in LudwigProfiler, LUDWIG_TAG is prepended to the specified experiment tag
# (LudwigProfiler(tag="...", ..)). This edited tag is passed in to `torch.profiler.record_function` so we can
# retrieve torch ops for the tagged code blocks/functions.
20 changes: 20 additions & 0 deletions ludwig/data/cache/manager.py
@@ -155,3 +155,23 @@ def data_format(self) -> str:
@property
def credentials(self) -> Optional[dict]:
return self._cache_credentials


class HyperoptSyncManager:
def __init__(
self,
hyperopt_sync_dir: Optional[str] = None,
hyperopt_sync_credentials: Optional[Union[str, dict]] = None,
):
self._hyperopt_sync_dir = hyperopt_sync_dir
if isinstance(hyperopt_sync_credentials, str):
hyperopt_sync_credentials = data_utils.load_json(hyperopt_sync_credentials)
self._hyperopt_sync_credentials = hyperopt_sync_credentials

@property
def credentials(self) -> Union[dict, None]:
return self._hyperopt_sync_credentials

@property
def sync_dir(self) -> Union[str, None]:
return self._hyperopt_sync_dir
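
A minimal usage sketch of the new `HyperoptSyncManager`; the directory and credentials file are illustrative. Per the constructor above, a string credentials argument is loaded as JSON, while a dict is stored as-is.

```python
from ludwig.data.cache.manager import HyperoptSyncManager

manager = HyperoptSyncManager(
    hyperopt_sync_dir="s3://my-bucket/hyperopt-checkpoints",     # remote directory trials sync to
    hyperopt_sync_credentials="/home/user/s3_credentials.json",  # str path is loaded via data_utils.load_json
)

assert manager.sync_dir == "s3://my-bucket/hyperopt-checkpoints"
credentials = manager.credentials  # dict parsed from the JSON file, or None
```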
75 changes: 55 additions & 20 deletions ludwig/hyperopt/execution.py
@@ -29,13 +29,14 @@
from ludwig.backend import initialize_backend, RAY
from ludwig.backend.ray import initialize_ray
from ludwig.callbacks import Callback
from ludwig.constants import MAXIMIZE, TEST, TRAINER, TRAINING, TYPE, VALIDATION
from ludwig.constants import HYPEROPT_LOCAL_DIR, MAXIMIZE, TEST, TRAINER, TRAINING, TYPE, VALIDATION
from ludwig.globals import MODEL_HYPERPARAMETERS_FILE_NAME, TRAIN_SET_METADATA_FILE_NAME
from ludwig.hyperopt.results import HyperoptResults, TrialResults
from ludwig.hyperopt.search_algos import get_search_algorithm
from ludwig.hyperopt.utils import load_json_values, substitute_parameters
from ludwig.modules.metric_modules import get_best_function
from ludwig.utils import metric_utils
from ludwig.utils.data_utils import hash_dict, NumpyEncoder
from ludwig.utils.data_utils import hash_dict, NumpyEncoder, save_json, use_credentials
from ludwig.utils.defaults import default_random_seed, merge_with_defaults
from ludwig.utils.fs_utils import has_remote_protocol, safe_move_file
from ludwig.utils.misc_utils import get_from_registry
@@ -44,7 +45,8 @@
if _ray_200:
from ray.air import Checkpoint
from ray.tune.search import SEARCH_ALG_IMPORT
from ray.tune.syncer import get_node_to_storage_syncer, SyncConfig

# from ray.tune.syncer import get_node_to_storage_syncer, SyncConfig
else:
from ray.ml import Checkpoint
from ray.tune.suggest import SEARCH_ALG_IMPORT
@@ -165,6 +167,8 @@ def __init__(
self.max_concurrent_trials = max_concurrent_trials
self.sync_config = None
self.sync_client = None
self.sync_function_template = kwargs.get("sync_function_template", None)
self.delete_function_template = kwargs.get("delete_function_template", None)
# Head node is the node to which all checkpoints are synced if running on a K8s cluster.
self.head_node_ip = ray.util.get_node_ip_address()

@@ -303,26 +307,27 @@ def _cpu_resources_per_trial_non_none(self):
def _gpu_resources_per_trial_non_none(self):
return self.gpu_resources_per_trial if self.gpu_resources_per_trial is not None else 0

def _get_remote_checkpoint_dir(self, trial_dir: Path) -> Optional[Union[str, Tuple[str, str]]]:
def _get_remote_checkpoint_dir(
self, trial_dir: Path, use_tmp: bool = False
) -> Optional[Union[str, Tuple[str, str]]]:
"""Get the path to remote checkpoint directory."""
if self.sync_config is None:
return None

if self.sync_config.upload_dir is not None:
# Cloud storage sync config
remote_checkpoint_dir = os.path.join(
self.sync_config.upload_dir, *_get_relative_checkpoints_dir_parts(trial_dir)
)
return remote_checkpoint_dir
if use_tmp:
return os.path.join(self.sync_config.upload_dir, "tmp", *_get_relative_checkpoints_dir_parts(trial_dir))
return os.path.join(self.sync_config.upload_dir, *_get_relative_checkpoints_dir_parts(trial_dir))
elif self.kubernetes_namespace is not None:
# Kubernetes sync config. Returns driver node name and path.
# When running on kubernetes, each trial is rsynced to the node running the main process.
node_name = self._get_kubernetes_node_address_by_ip()(self.head_node_ip)
return (node_name, trial_dir)
else:
logger.warning(
"Checkpoint syncing disabled as syncing is only supported to remote cloud storage or on Kubernetes "
"clusters is supported. To use syncing, set the kubernetes_namespace in the config or use a cloud URI "
"Checkpoint syncing disabled as it is only supported to remote cloud storage or on Kubernetes "
"clusters. To use syncing, set the kubernetes_namespace in the config or use a cloud URI "
"as the output directory."
)
return None
@@ -356,6 +361,8 @@ def _remove_partial_checkpoints(trial_path: str):
@contextlib.contextmanager
def _get_best_model_path(self, trial_path: str, analysis: ExperimentAnalysis) -> str:
remote_checkpoint_dir = self._get_remote_checkpoint_dir(Path(trial_path))
# If remote checkpoint dir, sync down artifacts from remote to local
print(f"[_get_best_model_path] syncing down artifacts from {remote_checkpoint_dir} to {trial_path}")
if remote_checkpoint_dir is not None:
self.sync_client.sync_down(remote_checkpoint_dir, trial_path)
self.sync_client.wait_or_retry()
@@ -452,7 +459,7 @@ def _run_experiment(
modified_config = merge_with_defaults(modified_config)

hyperopt_dict["config"] = modified_config
hyperopt_dict["experiment_name "] = f'{hyperopt_dict["experiment_name"]}_{trial_id}'
hyperopt_dict["experiment_name"] = f'{hyperopt_dict["experiment_name"]}_{trial_id}'
hyperopt_dict["output_directory"] = str(trial_dir)

tune_executor = self
@@ -488,8 +495,8 @@ def __init__(self):
self.resume_ckpt_ref = None

def _get_remote_checkpoint_dir(self) -> Optional[Union[str, Tuple[str, str]]]:
# sync client has to be recreated to avoid issues with serialization
return tune_executor._get_remote_checkpoint_dir(trial_dir)
# Sync client has to be recreated to avoid issues with serialization
return tune_executor._get_remote_checkpoint_dir(trial_dir, use_tmp=True)

def _checkpoint_progress(self, trainer, progress_tracker, save_path) -> None:
"""Checkpoints the progress tracker."""
@@ -778,13 +785,31 @@ def run_experiment_trial(config, local_hyperopt_dict, checkpoint_dir=None):
)

if has_remote_protocol(output_directory):
run_experiment_trial = tune.durable(run_experiment_trial)
self.sync_config = tune.SyncConfig(sync_to_driver=False, upload_dir=output_directory)
if not _ray_200:
run_experiment_trial = tune.durable(run_experiment_trial)

# Build Sync Client
if _ray_200:
self.sync_client = get_node_to_storage_syncer(SyncConfig(upload_dir=output_directory))
from ludwig.hyperopt.syncer import RemoteSyncer

# self.sync_client = get_node_to_storage_syncer(SyncConfig(upload_dir=output_directory))
self.sync_client = RemoteSyncer(backend=backend)
# elif self.sync_function_template:
# self.sync_client = CommandBasedClient(
# sync_up_template=self.sync_function_template,
# sync_down_template=self.sync_function_template,
# delete_template=self.delete_function_template, # No errors if this is None
# )
else:
self.sync_client = get_cloud_sync_client(output_directory)
output_directory = None

# Build Sync Config
# if self.sync_function_template:
if _ray_200:
self.sync_config = tune.SyncConfig(upload_dir=output_directory, syncer=self.sync_client)
else:
self.sync_config = tune.SyncConfig(sync_to_driver=False, upload_dir=output_directory)

elif self.kubernetes_namespace:
from ray.tune.integration.kubernetes import KubernetesSyncClient, NamespacedKubernetesSyncer

@@ -803,6 +828,7 @@ def _register(name, trainable):
# otherwise will start a new experiment:
# https://docs.ray.io/en/latest/tune/tutorials/tune-stopping.html
should_resume = "AUTO" if resume is None else resume
checkpoint_score_attr = mode + "-" + metric

try:
analysis = tune.run(
@@ -816,11 +842,12 @@ def _register(name, trainable):
search_alg=search_alg,
num_samples=self.num_samples,
keep_checkpoints_num=1,
checkpoint_score_attr=checkpoint_score_attr,
max_failures=1, # retry a trial failure once
resources_per_trial=resources_per_trial,
time_budget_s=self.time_budget_s,
sync_config=self.sync_config,
local_dir=output_directory,
local_dir=HYPEROPT_LOCAL_DIR,
metric=metric,
mode=mode,
trial_name_creator=lambda trial: f"trial_{trial.trial_id}",
@@ -858,6 +885,7 @@ def _register(name, trainable):
if validation_set is not None and validation_set.size > 0:
trial_path = trial["trial_dir"]
with self._get_best_model_path(trial_path, analysis) as best_model_path:
print(f"Best model path: {best_model_path}")
if best_model_path is not None:
self._evaluate_best_model(
trial,
@@ -877,13 +905,20 @@ def _register(name, trainable):
else:
logger.warning("Skipping evaluation as no model checkpoints were available")
else:
logger.warning("Skipping evaluation as no validation set was provided")
logger.warning("Skipping evaluation as no validation set was provided.")

ordered_trials = [TrialResults.from_dict(load_json_values(kwargs)) for kwargs in temp_ordered_trials]
else:
logger.warning("No trials reported results; check if time budget lower than epoch latency")
logger.warning("No trials reported results; check if time budget is lower than epoch latency.")
ordered_trials = []

# Remove temporary trials directory if it was created during syncing through RayTuneReportCallback
if has_remote_protocol(output_directory):
with use_credentials(backend.hyperopt_sync_manager.credentials):
tmp_remote_dir_path = os.path.join(backend.hyperopt_sync_manager.sync_dir, "tmp", experiment_name)
if path_exists(tmp_remote_dir_path):
delete(tmp_remote_dir_path, recursive=True)

return HyperoptResults(ordered_trials=ordered_trials, experiment_analysis=analysis)


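The `RemoteSyncer` imported from `ludwig.hyperopt.syncer` is not included in this diff. As a rough sketch only, assuming Ray 2.0's `Syncer` interface (`sync_up`/`sync_down`/`delete`), it could delegate to fsspec using the backend's hyperopt sync credentials; everything below other than `use_credentials` and `backend.hyperopt_sync_manager` is an assumption, not the PR's actual implementation.

```python
from typing import List, Optional

import fsspec
from ray.tune.syncer import Syncer

from ludwig.utils.data_utils import use_credentials


class RemoteSyncer(Syncer):
    """Assumed sketch of a node-to-remote-storage syncer backed by fsspec."""

    def __init__(self, backend, sync_period: float = 300.0):
        super().__init__(sync_period=sync_period)
        self.backend = backend

    def sync_up(self, local_dir: str, remote_dir: str, exclude: Optional[List] = None) -> bool:
        # Push local trial checkpoints up to remote storage (e.g. s3://...).
        with use_credentials(self.backend.hyperopt_sync_manager.credentials):
            fs, remote_path = fsspec.core.url_to_fs(remote_dir)
            fs.put(local_dir, remote_path, recursive=True)
        return True

    def sync_down(self, remote_dir: str, local_dir: str, exclude: Optional[List] = None) -> bool:
        # Pull remote checkpoints down to the local trial directory.
        with use_credentials(self.backend.hyperopt_sync_manager.credentials):
            fs, remote_path = fsspec.core.url_to_fs(remote_dir)
            fs.get(remote_path, local_dir, recursive=True)
        return True

    def delete(self, remote_dir: str) -> bool:
        # Remove a remote checkpoint directory (used for the temporary "tmp" trees above).
        with use_credentials(self.backend.hyperopt_sync_manager.credentials):
            fs, remote_path = fsspec.core.url_to_fs(remote_dir)
            fs.rm(remote_path, recursive=True)
        return True
```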