
Commit

Additional fixes (#1)
* comment out pythia revision

* suppress tango.cli error

* unified distribute function; filter for stream logs

* updates for valid configs

* update to torchrunx 0.2.1
apoorvkh authored Oct 31, 2024
1 parent 919c675 commit 1fd6a8e
Showing 8 changed files with 165 additions and 108 deletions.
20 changes: 11 additions & 9 deletions experiments/__init__.py
@@ -24,8 +24,10 @@
from tqdm import tqdm

from experiments.__tango__ import TangoStringHash, step, tango_executor, tango_settings, tango_workspace
+from experiments.__torchrunx__ import distribute
from experiments.config import GpuT

__all__ = ["SlurmJob", "Experiment", "Sweep", "distribute", "TangoStringHash", "step"]

@dataclass(unsafe_hash=True) # for batching jobs
class SlurmJob:
@@ -115,12 +117,15 @@ def _execute_step_graph(self) -> None:
# if CLI was already initialized
mp.set_start_method(None, force=True)

-with tango_cli(tango_settings):
-    tango.cli.execute_step_graph(
-        step_graph=self.step_graph,
-        workspace=tango_workspace,
-        executor=tango_executor,
-    )
+try:
+    with tango_cli(tango_settings):
+        tango.cli.execute_step_graph(
+            step_graph=self.step_graph,
+            workspace=tango_workspace,
+            executor=tango_executor,
+        )
+except tango.common.exceptions.CliRunError:
+    pass

def is_cached(self) -> bool:
for s in self.step_dict.values():
@@ -289,6 +294,3 @@ def run(
@classmethod
def cli(cls) -> None:
tyro.cli(cls.run)


-__all__ = ["SlurmJob", "Experiment", "Sweep", "TangoStringHash", "step"]
58 changes: 58 additions & 0 deletions experiments/__torchrunx__.py
@@ -0,0 +1,58 @@
import logging
import os
from datetime import datetime
from pathlib import Path
from typing import Any, Callable

import torchrunx


def build_logging_handlers(hostnames):
    log_dir = os.environ["TORCHRUNX_LOG_DIR"]
    Path(log_dir).mkdir(parents=True, exist_ok=True)

    timestamp = datetime.now().isoformat(timespec="seconds")
    file_paths = [f"{log_dir}/{timestamp}-{hostname}.log" for hostname in hostnames]

    def _handler_builder() -> list[logging.Handler]:
        handlers = []

        stream_handler = torchrunx.stream_handler(hostname=hostnames[0], local_rank=0)
        stream_handler.addFilter(logging.Filter(name="academic-pretraining"))
        handlers.append(stream_handler)

        for hostname, file_path in zip(hostnames, file_paths):
            handlers += [
                torchrunx.file_handler(hostname=hostname, local_rank=None, file_path=file_path),
                torchrunx.file_handler(hostname=hostname, local_rank=0, file_path=file_path),
            ]
        return handlers

    return _handler_builder, file_paths


def distribute(
    func: Callable,
    func_args: tuple[Any] | None = None,
    func_kwargs: dict[str, Any] | None = None,
    hostnames: list[str] | None = None,
    workers_per_host: int | None = None,
) -> Any:
    if hostnames is None:
        hostnames = torchrunx.utils.environment.auto_hosts()
    if workers_per_host is None:
        workers_per_host = torchrunx.utils.environment.auto_workers()

    log_handlers_builder, log_files = build_logging_handlers(hostnames)

    print(f"Logging results of \"{func.__name__}\" to:")
    for file_path in log_files:
        print(f" - {file_path}")

    return torchrunx.launch(
        func=func,
        func_kwargs=func_kwargs,
        hostnames=hostnames,
        workers_per_host=workers_per_host,
        log_handlers_builder=log_handlers_builder,
    ).rank(0)
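
For reference, a minimal usage sketch of the new distribute() wrapper (not part of this commit): the worker function and log directory below are hypothetical, and it assumes torchrunx can auto-detect the hosts when hostnames are omitted.

# Hypothetical example: run a function on all workers via the unified wrapper.
import os

from experiments import distribute


def hello_worker(name: str) -> str:
    # Executed on every worker; distribute() returns the rank-0 result.
    return f"hello from {name}"


if __name__ == "__main__":
    os.environ.setdefault("TORCHRUNX_LOG_DIR", "logs/torchrunx")  # assumed log location
    result = distribute(
        func=hello_worker,
        func_kwargs={"name": "academic-pretraining"},
        workers_per_host=1,  # hostnames are auto-detected when not given
    )
    print(result)

Per-host log files are written under TORCHRUNX_LOG_DIR, while the console stream handler is filtered to the "academic-pretraining" logger namespace, per the commit message's "filter for stream logs".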
57 changes: 24 additions & 33 deletions experiments/training_time_empirical.py
@@ -1,18 +1,16 @@
import dataclasses
import math
-import os
import tempfile
from dataclasses import dataclass
-from typing import Any, Sequence, TypedDict
+from typing import Any, TypedDict

import torch
-import torchrunx
from src.benchmarking.max_batch_size import find_max_mbs_pow2
from src.benchmarking.step_time import estimate_step_time
from src.benchmarking.utils import ManualTrainer
from tango import Step

-from experiments import Experiment, SlurmJob, step
+from experiments import Experiment, SlurmJob, distribute, step
from experiments.config import TrainingConfig


@@ -52,13 +50,11 @@ def find_largest_batch_size_worker(config: TrainingConfig, limit: int):

@step(cacheable=True, version="001")
def find_largest_batch_size(config: TrainingConfig, limit: int) -> int:
-return torchrunx.launch(
+return distribute(
func=find_largest_batch_size_worker,
func_kwargs={"config": config, "limit": limit},
-hostnames=torchrunx.slurm_hosts(),
workers_per_host=config.gpus_per_node,
-log_dir=os.environ["TORCHRUNX_LOG_DIR"],
-)[0]
+)


class BenchmarkingResults(TypedDict):
@@ -98,7 +94,7 @@ def benchmark_step_time(

while micro_batch_size > 0:
try:
-benchmark_results = torchrunx.launch(
+benchmark_results = distribute(
func=benchmark_step_time_worker,
func_kwargs=dict(
config=config,
@@ -107,14 +103,12 @@
target_micro_batch_size=target_micro_batch_size,
num_benchmarking_steps=num_benchmarking_steps,
),
-hostnames=torchrunx.slurm_hosts(),
workers_per_host=config.gpus_per_node,
-log_dir=os.environ["TORCHRUNX_LOG_DIR"],
-)[0]
+)
except RuntimeError:
if config.free_lunch:
print("Possible time-out during compile, trying again without compiling")
benchmark_results = torchrunx.launch(
print("Possible time-out during compile, trying again without compiling...")
benchmark_results = distribute(
func=benchmark_step_time_worker,
func_kwargs=dict(
config=config,
@@ -123,10 +117,8 @@
target_micro_batch_size=target_micro_batch_size,
num_benchmarking_steps=num_benchmarking_steps,
),
-hostnames=torchrunx.slurm_hosts(),
workers_per_host=config.gpus_per_node,
-log_dir=os.environ["TORCHRUNX_LOG_DIR"],
-)[0]
+)
else:
raise

@@ -146,6 +138,9 @@ def compute_training_days(benchmarking_results: BenchmarkingResults | None, num_
return (num_steps * benchmarking_results["step_time"]) / (24 * 60 * 60)


+## Experiment


@dataclass
class TrainingTimeEmpirical(Experiment):
config: TrainingConfig
@@ -168,23 +163,23 @@ def is_valid(self) -> bool:
[
self.benchmarking_steps <= 0,
self.trial < 0,
-# target batch size should be power of 2
-not math.log2(self.model_class.batch_size).is_integer(),
-# target batch size should be evenly divisible by total GPUs
+# model batch size should be evenly divisible by total GPUs
self.model_class.batch_size % (self.config.num_nodes * self.config.gpus_per_node) > 0,
+# batch size per gpu should be power of 2
+not math.log2(
+    self.model_class.batch_size // (self.config.num_nodes * self.config.gpus_per_node)
+).is_integer(),
# if activation checkpointing is enabled, model should support it
-(not self.model_class.supports_activation_checkpointing) and self.config.activation_checkpointing,
+self.config.activation_checkpointing and (not self.model_class.supports_activation_checkpointing),
# data types for ampere or newer GPUs
-(not self.config.ampere_or_newer_gpu() and (self.training_class.tf32 or self.training_class.bf16)),
+self.model_class.mixed_precision == "bf16" and not self.config.ampere_or_newer_gpu(),
# don't shard for a single GPU (no-op)
-(
-    self.config.num_nodes == 1
-    and self.config.gpus_per_node == 1
-    and self.config.sharding != ""
-    and not self.config.offloading
-),
+self.config.num_nodes == 1
+and self.config.gpus_per_node == 1
+and self.config.sharding != ""
+and not self.config.offloading,
# offloading requires sharding
-(self.config.sharding == "" and self.config.offloading),
+(self.config.offloading and self.config.sharding == ""),
]
):
return False
@@ -222,10 +217,6 @@ def slurm_job(self) -> SlurmJob | None:
gpu_type=self.config.gpu_type,
)

-@property
-def dependencies(self) -> Sequence[Experiment]:
-    return []

def results(self):
return {
"max_micro_batch_size": self.step_result("max_micro_batch_size"),
(Diffs for the remaining 5 changed files not shown.)
