diff --git a/dev-requirements.txt b/dev-requirements.txt index 1b0ae50a3..437a86ea9 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,3 +1,4 @@ +accelerate aiobotocore ax-platform[mysql]==0.2.3 black==23.3.0 @@ -20,7 +21,7 @@ protobuf==3.20.3 pyre-extensions pyre-check pytest -pytorch-lightning==1.5.10 +pytorch-lightning==2.2.0 torch-model-archiver>=0.4.2 torch>=1.10.0 torchmetrics<0.11.0 diff --git a/torchx/components/dist.py b/torchx/components/dist.py index 3ffc08339..ad2e376ed 100644 --- a/torchx/components/dist.py +++ b/torchx/components/dist.py @@ -350,3 +350,118 @@ def parse_nnodes(j: str) -> Tuple[int, int, int, str]: f"Invalid format for -j, usage example: 1:2x4 or 1x4 or 4. Given: {j}" ) return int(min_nnodes), int(max_nnodes), int(nproc_per_node), nnodes_rep + +def accelerate( + *script_args: str, + script: Optional[str] = None, + image: str = torchx.IMAGE, + name: str = "/", + h: Optional[str] = None, + cpu: int = 2, + gpu: int = 0, + memMB: int = 1024, + j: str = "1x2", + env: Optional[Dict[str, str]] = None, + max_retries: int = 0, + main_process_port: int = 29500, + accelerate_args: Optional[List[str]] = None, + mounts: Optional[List[str]] = None, + debug: bool = False, +) -> specs.AppDef: + """ + A component that uses HuggingFace accelerate to launch the job. + + Args: + script_args: arguments to the main module + script: script or binary to run within the image + image: image (e.g. docker) + name: job name override in the following format: ``{experimentname}/{runname}`` or ``{experimentname}/`` or ``/{runname}`` or ``{runname}``. + Uses the script or module name if ``{runname}`` not specified. + cpu: number of cpus per replica + gpu: number of gpus per replica + memMB: cpu memory in MB per replica + h: a registered named resource (if specified takes precedence over cpu, gpu, memMB) + j: [{min_nnodes}:]{nnodes}x{nproc_per_node}, for gpu hosts, nproc_per_node must not exceed num gpus + env: environment varibles to be passed to the run (e.g. ENV1=v1,ENV2=v2,ENV3=v3) + max_retries: the number of scheduler retries allowed + main_process_port: the port on rank0's host to use for coordinating the workers. + Only takes effect when running multi-node. When running single node, this parameter + is ignored and a random free port is chosen. + mounts: mounts to mount into the worker environment/container (ex. type=,src=/host,dst=/job[,readonly]). + See scheduler documentation for more info. + debug: whether to run with preset debug flags enabled + """ + + # nnodes: number of nodes or minimum nodes for elastic launch + # max_nnodes: maximum number of nodes for elastic launch + # nproc_per_node: number of processes on each node + min_nnodes, max_nnodes, nproc_per_node, nnodes_rep = parse_nnodes(j) + + assert min_nnodes == max_nnodes, "accelerate component doesn't support elasticity" + + if max_nnodes == 1: + # using port 0 makes elastic chose a free random port which is ok + # for single-node jobs since all workers run under a single agent + # When nnodes is 0 and max_nnodes is 1, it's stil a single node job + # but pending until the resources become available + main_process_ip = "localhost" + else: + # for multi-node, rely on the rank0_env environment variable set by + # the schedulers (see scheduler implementation for the actual env var this maps to) + # some schedulers (e.g. aws batch) make the rank0's ip-addr available on all BUT on rank0 + # so default to "localhost" if the env var is not set or is empty + # rdzv_endpoint bash resolves to something to the effect of + # ${TORCHX_RANK0_HOST:=localhost}:29500 + # use $$ in the prefix to escape the '$' literal (rather than a string Template substitution argument) + main_process_ip = _noquote(f"$${{{macros.rank0_env}:=localhost}}") + + argname = StructuredNameArgument.parse_from( + name=name, + m=None, + script=script, + ) + + if env is None: + env = {} + + if debug: + env.update(_TORCH_DEBUG_FLAGS) + + env["TORCHX_TRACKING_EXPERIMENT_NAME"] = argname.experiment_name + + cmd = [ + "accelerate", + "launch", + f"--main_process_ip", + main_process_ip, + f"--main_process_port={main_process_port}", + f"--num_machines={max_nnodes}", + f"--num_processes={nproc_per_node*max_nnodes}", + f"--machine_rank={macros.replica_id}", + f"--max_restarts={max_retries}", + ] + if accelerate_args is not None: + cmd += accelerate_args + cmd += [script] + cmd += script_args + + return specs.AppDef( + name=argname.run_name, + roles=[ + specs.Role( + name=get_role_name(script, None), + image=image, + min_replicas=min_nnodes, + entrypoint="bash", + num_replicas=int(max_nnodes), + resource=specs.resource(cpu=cpu, gpu=gpu, memMB=memMB, h=h), + args=["-c", _args_join(cmd)], + env=env, + port_map={ + "accelerate": main_process_port, + }, + max_retries=max_retries, + mounts=specs.parse_mounts(mounts) if mounts else [], + ) + ], + ) diff --git a/torchx/components/integration_tests/component_provider.py b/torchx/components/integration_tests/component_provider.py index 442e57a3a..020967c2c 100644 --- a/torchx/components/integration_tests/component_provider.py +++ b/torchx/components/integration_tests/component_provider.py @@ -49,6 +49,22 @@ def get_app_def(self) -> AppDef: ) +class AccelerateComponentProvider(ComponentProvider): + def get_app_def(self) -> AppDef: + return dist_components.accelerate( + script="torchx/examples/apps/compute_world_size/main.py", + name="accelerate-compute", + image=self._image, + cpu=1, + j="2x2", + max_retries=3, + main_process_port=19501, + env={ + "LOGLEVEL": "INFO", + }, + ) + + class ServeComponentProvider(ComponentProvider): # TODO(aivanou): Remove dryrun and test e2e serve component+app def get_app_def(self) -> AppDef: diff --git a/torchx/components/test/dist_test.py b/torchx/components/test/dist_test.py index 296742d3b..44880922f 100644 --- a/torchx/components/test/dist_test.py +++ b/torchx/components/test/dist_test.py @@ -5,7 +5,10 @@ # LICENSE file in the root directory of this source tree. from torchx.components.component_test_base import ComponentTestCase -from torchx.components.dist import _TORCH_DEBUG_FLAGS, ddp, parse_nnodes, spmd +from torchx.components.dist import ( + _TORCH_DEBUG_FLAGS, ddp, parse_nnodes, spmd, + accelerate, +) class DDPTest(ComponentTestCase): @@ -148,3 +151,37 @@ def test_spmd_call_by_module_or_script_with_run_name(self) -> None: "default-experiment", appdef.roles[0].env["TORCHX_TRACKING_EXPERIMENT_NAME"], ) + +class AccelerateTest(ComponentTestCase): + def test_ddp(self) -> None: + import torchx.components.dist as dist + + self.validate(dist, "accelerate") + + def test_basic(self) -> None: + app = accelerate( + "--script_arg", + script="foo.py", + j="2x2", + accelerate_args=["--accelerate_arg"], + env={"a": "b"} + ) + self.assertEqual(len(app.roles), 1) + role = app.roles[0] + self.assertEqual(role.num_replicas, 2) + args = " ".join(role.args) + self.assertIn("--script_arg", args) + self.assertIn("--accelerate_arg", args) + self.assertIn("a", role.env) + + def test_mounts(self) -> None: + app = accelerate( + script="foo.py", mounts=["type=bind", "src=/dst", "dst=/dst", "readonly"] + ) + self.assertEqual(len(app.roles[0].mounts), 1) + + def test_debug(self) -> None: + app = accelerate(script="foo.py", debug=True) + env = app.roles[0].env + for k, v in _TORCH_DEBUG_FLAGS.items(): + self.assertEqual(env[k], v) diff --git a/torchx/examples/apps/lightning/model.py b/torchx/examples/apps/lightning/model.py index 973caf838..aa57c55f6 100644 --- a/torchx/examples/apps/lightning/model.py +++ b/torchx/examples/apps/lightning/model.py @@ -47,8 +47,8 @@ def __init__( m.fc.out_features = 200 self.model: ResNet = m - self.train_acc = Accuracy() - self.val_acc = Accuracy() + self.train_acc = Accuracy(task="multiclass", num_classes=1000) + self.val_acc = Accuracy(task="multiclass", num_classes=1000) # pyre-fixme[14] def forward(self, x: torch.Tensor) -> torch.Tensor: diff --git a/torchx/examples/apps/lightning/profiler.py b/torchx/examples/apps/lightning/profiler.py index bca66531b..3053ea42f 100644 --- a/torchx/examples/apps/lightning/profiler.py +++ b/torchx/examples/apps/lightning/profiler.py @@ -17,18 +17,18 @@ import time from typing import Dict -from pytorch_lightning.loggers.base import LightningLoggerBase -from pytorch_lightning.profiler.base import BaseProfiler +from pytorch_lightning.loggers.logger import Logger +from pytorch_lightning.profilers import Profiler -class SimpleLoggingProfiler(BaseProfiler): +class SimpleLoggingProfiler(Profiler): """ This profiler records the duration of actions (in seconds) and reports the mean duration of each action to the specified logger. Reported metrics are in the format `duration_`. """ - def __init__(self, logger: LightningLoggerBase) -> None: + def __init__(self, logger: Logger) -> None: super().__init__() self.current_actions: Dict[str, float] = {}