diff --git a/src/datatrove/executor/slurm.py b/src/datatrove/executor/slurm.py
index 91c86b71..3e74a662 100644
--- a/src/datatrove/executor/slurm.py
+++ b/src/datatrove/executor/slurm.py
@@ -1,6 +1,7 @@
 from __future__ import annotations

 import json
+import math
 import os
 import random
 import signal
@@ -79,7 +80,7 @@ class SlurmPipelineExecutor(PipelineExecutor):
         mail_type: see https://slurm.schedmd.com/sbatch.html. Common values are (NONE, BEGIN, END, FAIL, REQUEUE, ALL)
         mail_user: email address to send notifications to
         requeue: requeue the job if it fails
-
+        tasks_per_job: each slurm job in the job array will run this many datatrove tasks. This reduces the total nb of slurm jobs launched.
     """

     def __init__(
@@ -111,6 +112,7 @@ def __init__(
         mail_type: str = "ALL",
         mail_user: str = None,
         requeue: bool = True,
+        tasks_per_job: int = 1,
     ):
         super().__init__(pipeline, logging_dir, skip_completed)
         self.tasks = tasks
@@ -118,6 +120,7 @@ def __init__(
         self.partition = partition
         self.cpus_per_task = cpus_per_task
         self.mem_per_cpu_gb = mem_per_cpu_gb
+        self.tasks_per_job = tasks_per_job
         self.time = time
         self.job_name = job_name
         self.qos = qos
@@ -160,18 +163,23 @@ def run(self):
             slurm_rank = int(os.environ["SLURM_ARRAY_TASK_ID"]) + self.max_array_size * int(
                 os.environ.get("RUN_OFFSET", 0)
             )
+            ranks_to_run_range = (slurm_rank * self.tasks_per_job, (slurm_rank + 1) * self.tasks_per_job)
             with self.logging_dir.open("ranks_to_run.json", "r") as ranks_to_run_file:
                 all_ranks = json.load(ranks_to_run_file)
-            if slurm_rank >= len(all_ranks):
+            if ranks_to_run_range[0] >= len(all_ranks):
                 return
-            rank = all_ranks[slurm_rank]

             for ss in self.requeue_signals or []:
                 signal.signal(signal.Signals[ss], requeue_handler)

-            if self.randomize_start:
-                time.sleep(random.randint(0, 60 * 3))
-            self._run_for_rank(rank)
+            for rank_to_run in range(*ranks_to_run_range):
+                if rank_to_run >= len(all_ranks):
+                    break
+                rank = all_ranks[rank_to_run]
+
+                if self.randomize_start:
+                    time.sleep(random.randint(0, 60 * 3))
+                self._run_for_rank(rank)
         else:
             # we still have to launch the job
             self.launch_job()
@@ -244,7 +252,8 @@ def launch_job(self):
                 # we actually save this (only once) to avoid race conditions
                 json.dump(ranks_to_run, ranks_to_run_file)

-        max_array = min(len(ranks_to_run), self.max_array_size) if self.max_array_size != -1 else len(ranks_to_run)
+        nb_jobs_to_launch = math.ceil(len(ranks_to_run) / self.tasks_per_job)
+        max_array = min(nb_jobs_to_launch, self.max_array_size) if self.max_array_size != -1 else nb_jobs_to_launch

         # create the actual sbatch script
         launch_file_contents = self.get_launch_file_contents(
@@ -261,7 +270,7 @@ def launch_job(self):

         # launch (possibly multiple) jobs
         launched_jobs = 0
-        while launched_jobs * max_array < len(ranks_to_run):
+        while launched_jobs * max_array < nb_jobs_to_launch:
             if launched_jobs and self.max_array_launch_parallel and self.stagger_max_array_jobs > 0:
                 time.sleep(self.stagger_max_array_jobs)
             args = [f"--export=ALL,RUN_OFFSET={launched_jobs}"]
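
For illustration only (not part of the patch): a minimal sketch of the mapping this change introduces, assuming tasks_per_job=4 and 10 datatrove tasks. Each slurm array index now covers a contiguous slice of ranks, and launch_job() submits ceil(10 / 4) = 3 array jobs instead of 10. The variable names below are hypothetical and only mirror those in the diff.

import math

tasks_per_job = 4
all_ranks = list(range(10))  # stand-in for the contents of ranks_to_run.json

# same computation as launch_job(): how many slurm array jobs to submit
nb_jobs_to_launch = math.ceil(len(all_ranks) / tasks_per_job)  # -> 3

# same slicing as run(): each array index handles a contiguous range of ranks
for slurm_rank in range(nb_jobs_to_launch):
    start, stop = slurm_rank * tasks_per_job, (slurm_rank + 1) * tasks_per_job
    ranks = [all_ranks[i] for i in range(start, stop) if i < len(all_ranks)]
    print(slurm_rank, ranks)
# 0 [0, 1, 2, 3]
# 1 [4, 5, 6, 7]
# 2 [8, 9]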