From a941e1335dedf7be732f75ff2175998431bf82de Mon Sep 17 00:00:00 2001
From: Peter Kraus
Date: Wed, 20 Nov 2024 12:26:26 +0100
Subject: [PATCH] Store job metadata in NetCDF.

---
 src/tomato/daemon/cmd.py       |  4 +++-
 src/tomato/daemon/io.py        | 22 ++++++++++++++----
 src/tomato/daemon/job.py       | 43 +++++++++++++++++++---------------
 src/tomato/ketchup/__init__.py |  7 +++---
 4 files changed, 48 insertions(+), 28 deletions(-)

diff --git a/src/tomato/daemon/cmd.py b/src/tomato/daemon/cmd.py
index 4a9e257..f874bf4 100644
--- a/src/tomato/daemon/cmd.py
+++ b/src/tomato/daemon/cmd.py
@@ -232,11 +232,13 @@ def job(msg: dict, daemon: Daemon) -> Reply:
         daemon.jobs[jobid] = Job(id=jobid, **msg.get("params", {}))
         logger.info("received job %d", jobid)
         daemon.nextjob += 1
+        ret = daemon.jobs[jobid]
     else:
         for k, v in msg.get("params", {}).items():
             logger.debug("setting job parameter %s.%s to %s", jobid, k, v)
             setattr(daemon.jobs[jobid], k, v)
         cjob = daemon.jobs[jobid]
+        ret = cjob
         if cjob.status in {"c"}:
             daemon.jobs[jobid] = CompletedJob(
                 id=cjob.id,
@@ -246,7 +248,7 @@ def job(msg: dict, daemon: Daemon) -> Reply:
                 jobpath=cjob.jobpath,
                 respath=cjob.respath,
             )
-    return Reply(success=True, msg="job updated", data=daemon.jobs[jobid])
+    return Reply(success=True, msg="job updated", data=ret)


 def driver(msg: dict, daemon: Daemon) -> Reply:
diff --git a/src/tomato/daemon/io.py b/src/tomato/daemon/io.py
index 82a5d4d..5fc8103 100644
--- a/src/tomato/daemon/io.py
+++ b/src/tomato/daemon/io.py
@@ -9,8 +9,9 @@
 import pickle
 import logging
 import xarray as xr
+import importlib.metadata
 from pathlib import Path
-from tomato.models import Daemon
+from tomato.models import Daemon, Job

 logger = logging.getLogger(__name__)

@@ -42,22 +43,33 @@ def load(daemon: Daemon):
         daemon.status = "running"


-def merge_netcdfs(jobpath: Path, outpath: Path):
+def merge_netcdfs(job: Job, snapshot: bool = False):
     """
-    Merges the individual pickled :class:`xr.Datasets` of each Component found in
-    `jobpath` into a single :class:`xr.DataTree`, which is then stored in the NetCDF file,
+    Merges the individual pickled :class:`xr.Dataset` objects of each Component found
+    in :obj:`job.jobpath` into a single :class:`xr.DataTree`, which is then stored in a NetCDF file,
     using the Component `role` as the group label.
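+    If ``snapshot`` is ``True``, the file is written to :obj:`job.snappath`, else to :obj:`job.respath`.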
""" logger = logging.getLogger(f"{__name__}.merge_netcdf") logger.debug("opening datasets") datasets = [] - for fn in jobpath.glob("*.pkl"): + logger.debug(f"{job=}") + logger.debug(f"{job.jobpath=}") + for fn in Path(job.jobpath).glob("*.pkl"): with pickle.load(fn.open("rb")) as ds: datasets.append(ds) logger.debug("creating a DataTree from %d groups", len(datasets)) dt = xr.DataTree.from_dict({ds.attrs["role"]: ds for ds in datasets}) + logger.debug(f"{dt=}") + root_attrs = { + "tomato_version": importlib.metadata.version("tomato"), + "tomato_Job": job.model_dump_json(), + } + dt.attrs = root_attrs + outpath = job.snappath if snapshot else job.respath logger.debug("saving DataTree into '%s'", outpath) dt.to_netcdf(outpath, engine="h5netcdf") + logger.debug(f"{dt=}") def data_to_pickle(ds: xr.Dataset, path: Path, role: str): diff --git a/src/tomato/daemon/job.py b/src/tomato/daemon/job.py index e689f3d..5d35784 100644 --- a/src/tomato/daemon/job.py +++ b/src/tomato/daemon/job.py @@ -27,9 +27,9 @@ import psutil from tomato.daemon.io import merge_netcdfs, data_to_pickle -from tomato.models import Pipeline, Daemon, Component, Device, Driver +from tomato.models import Pipeline, Daemon, Component, Device, Driver, Job from dgbowl_schemas.tomato import to_payload -from dgbowl_schemas.tomato.payload import Payload, Task +from dgbowl_schemas.tomato.payload import Task logger = logging.getLogger(__name__) @@ -111,7 +111,7 @@ def manage_running_pips(daemon: Daemon, req): proc = psutil.Process(pid=job.pid) kill_tomato_job(proc) logger.info(f"job {job.id} with pid {job.pid} was terminated successfully") - merge_netcdfs(Path(job.jobpath), Path(job.respath)) + merge_netcdfs(job) reset = True params = dict(status="cd") # dead jobs marked as running (status == 'r') should be cleared @@ -259,7 +259,6 @@ def manager(port: int, timeout: int = 500): def lazy_pirate( pyobj: Any, retries: int, timeout: int, address: str, context: zmq.Context ) -> Any: - logger.debug("Here") req = context.socket(zmq.REQ) req.connect(address) poller = zmq.Poller() @@ -384,19 +383,27 @@ def tomato_job() -> None: respath = outpath / f"{prefix}.nc" snappath = outpath / f"snapshot.{jobid}.nc" params = dict(respath=str(respath), snappath=str(snappath), jobpath=str(jobpath)) - lazy_pirate(pyobj=dict(cmd="job", id=jobid, params=params), **pkwargs) + ret = lazy_pirate(pyobj=dict(cmd="job", id=jobid, params=params), **pkwargs) + if ret.success is False: + logger.error("could not set job status for unknown reason") + return 1 + job: Job = ret.data logger.info("handing off to 'job_main_loop'") logger.info("==============================") - job_main_loop(context, args.port, payload, pip, jobpath, snappath, logpath) + job_main_loop(context, args.port, job, pip, logpath) logger.info("==============================") - merge_netcdfs(jobpath, respath) + logger.info("job finished successfully") + job.completed_at = str(datetime.now(timezone.utc)) + job.status = "c" - logger.info("job finished successfully, attempting to set status to 'c'") - params = dict(status="c", completed_at=str(datetime.now(timezone.utc))) + logger.info("writing final data to a NetCDF file") + merge_netcdfs(job) + + logger.info("attempting to set job status to 'c'") + params = dict(status=job.status, completed_at=job.completed_at) ret = lazy_pirate(pyobj=dict(cmd="job", id=jobid, params=params), **pkwargs) - logger.debug(f"{ret=}") if ret.success is False: logger.error("could not set job status for unknown reason") return 1 @@ -438,7 +445,7 @@ def job_thread( kwargs = 
-    datapath = jobpath / f"{component.role}.pkl"
+    datapath = Path(jobpath) / f"{component.role}.pkl"
     logger.debug("distributing tasks:")
     for task in tasks:
         logger.debug(f"{task=}")
@@ -489,10 +496,8 @@
 def job_main_loop(
     context: zmq.Context,
     port: int,
-    payload: Payload,
+    job: Job,
     pipname: str,
-    jobpath: Path,
-    snappath: Path,
     logpath: Path,
 ) -> None:
     """
@@ -507,7 +512,7 @@ def job_main_loop(

     while True:
         req.send_pyobj(dict(cmd="status", sender=sender))
-        daemon = req.recv_pyobj().data
+        daemon: Daemon = req.recv_pyobj().data
         if all([drv.port is not None for drv in daemon.drvs.values()]):
             break
         else:
@@ -519,7 +524,7 @@ def job_main_loop(

     # collate steps by role
     plan = {}
-    for step in payload.method:
+    for step in job.payload.method:
         if step.component_tag not in plan:
             plan[step.component_tag] = []
         plan[step.component_tag].append(step)
@@ -540,20 +545,20 @@ def job_main_loop(
             logger.debug(" driver=%s", driver)
         threads[component.role] = Thread(
             target=job_thread,
-            args=(tasks, component, device, driver, jobpath, logpath),
+            args=(tasks, component, device, driver, job.jobpath, logpath),
             name="job-thread",
         )
         threads[component.role].start()

     # wait until threads join or we're killed
-    snapshot = payload.settings.snapshot
+    snapshot = job.payload.settings.snapshot
     t0 = time.perf_counter()
     while True:
         logger.debug("tick")
         tN = time.perf_counter()
         if snapshot is not None and tN - t0 > snapshot.frequency:
             logger.debug("creating snapshot")
-            merge_netcdfs(jobpath, snappath)
+            merge_netcdfs(job, snapshot=True)
             t0 += snapshot.frequency
         joined = [proc.is_alive() is False for proc in threads.values()]
         if all(joined):
diff --git a/src/tomato/ketchup/__init__.py b/src/tomato/ketchup/__init__.py
index 905f3fe..bc61196 100644
--- a/src/tomato/ketchup/__init__.py
+++ b/src/tomato/ketchup/__init__.py
@@ -25,7 +25,7 @@
 from dgbowl_schemas.tomato import to_payload

 from tomato.daemon.io import merge_netcdfs
-from tomato.models import Reply, Daemon
+from tomato.models import Reply, Daemon, Job

 log = logging.getLogger(__name__)

@@ -288,7 +288,7 @@ def snapshot(
         Success: snapshot for job [3] created successfully

     """
-    jobs = status.data.jobs
+    jobs: dict[int, Job] = status.data.jobs
     for jobid in jobids:
         if jobid not in jobs:
             return Reply(success=False, msg=f"job {jobid} does not exist")
@@ -296,7 +296,8 @@ def snapshot(
             return Reply(success=False, msg=f"job {jobid} is still queued")

     for jobid in jobids:
-        merge_netcdfs(Path(jobs[jobid].jobpath), Path(f"snapshot.{jobid}.nc"))
+        jobs[jobid].snappath = Path(f"snapshot.{jobid}.nc")
+        merge_netcdfs(jobs[jobid], snapshot=True)
     if len(jobids) > 1:
         msg = f"snapshot for jobs {jobids} created successfully"
     else:
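
The root attributes written by merge_netcdfs() above make the output file self-describing: "tomato_version" records the version of tomato that wrote it, and "tomato_Job" holds the full Job model serialized with model_dump_json(). A minimal sketch of reading that metadata back, assuming only what the patch sets; the file name "results.1.nc" is a hypothetical stand-in for job.respath:

    import json
    import xarray as xr

    # Open a merged file produced by merge_netcdfs(); the path is hypothetical.
    dt = xr.open_datatree("results.1.nc", engine="h5netcdf")

    # Root attributes set by this patch:
    print(dt.attrs["tomato_version"])
    job = json.loads(dt.attrs["tomato_Job"])  # round-trip of Job.model_dump_json()
    print(job["id"], job["status"], job["completed_at"])

    # Component data is still grouped by role, one group per Component:
    for role, node in dt.children.items():
        print(role, list(node.ds.data_vars))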