Store job metadata in netcdf.
PeterKraus committed Nov 20, 2024
1 parent fc47ee2 commit a941e13
Showing 4 changed files with 47 additions and 28 deletions.
4 changes: 3 additions & 1 deletion src/tomato/daemon/cmd.py
@@ -232,11 +232,13 @@ def job(msg: dict, daemon: Daemon) -> Reply:
         daemon.jobs[jobid] = Job(id=jobid, **msg.get("params", {}))
         logger.info("received job %d", jobid)
         daemon.nextjob += 1
+        ret = daemon.jobs[jobid]
     else:
         for k, v in msg.get("params", {}).items():
             logger.debug("setting job parameter %s.%s to %s", jobid, k, v)
             setattr(daemon.jobs[jobid], k, v)
         cjob = daemon.jobs[jobid]
+        ret = cjob
         if cjob.status in {"c"}:
             daemon.jobs[jobid] = CompletedJob(
                 id=cjob.id,
@@ -246,7 +248,7 @@ def job(msg: dict, daemon: Daemon) -> Reply:
                 jobpath=cjob.jobpath,
                 respath=cjob.respath,
             )
-    return Reply(success=True, msg="job updated", data=daemon.jobs[jobid])
+    return Reply(success=True, msg="job updated", data=ret)
 
 
 def driver(msg: dict, daemon: Daemon) -> Reply:
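Why `ret` is captured here: once a job reaches status "c", the daemon immediately replaces the entry in `daemon.jobs` with a slimmer `CompletedJob`, so returning `daemon.jobs[jobid]` would hand that replacement back to the caller. A minimal sketch of the intended behaviour (the jobid and the `daemon` fixture are hypothetical):

    # sketch only: assumes a daemon whose job 1 is in a running state
    reply = job({"id": 1, "params": {"status": "c"}}, daemon)
    assert isinstance(reply.data, Job)               # caller still gets the full Job
    assert isinstance(daemon.jobs[1], CompletedJob)  # daemon keeps the slimmed record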
21 changes: 16 additions & 5 deletions src/tomato/daemon/io.py
@@ -9,8 +9,9 @@
 import pickle
 import logging
 import xarray as xr
+import importlib.metadata
 from pathlib import Path
-from tomato.models import Daemon
+from tomato.models import Daemon, Job
 
 logger = logging.getLogger(__name__)
 
@@ -42,22 +42,32 @@ def load(daemon: Daemon):
     daemon.status = "running"
 
 
-def merge_netcdfs(jobpath: Path, outpath: Path):
+def merge_netcdfs(job: Job, snapshot=False):
     """
-    Merges the individual pickled :class:`xr.Datasets` of each Component found in
-    `jobpath` into a single :class:`xr.DataTree`, which is then stored in the NetCDF file,
+    Merges the individual pickled :class:`xr.Datasets` of each Component found in :obj:`job.jobpath`
+    into a single :class:`xr.DataTree`, which is then stored in the NetCDF file,
     using the Component `role` as the group label.
     """
     logger = logging.getLogger(f"{__name__}.merge_netcdf")
     logger.debug("opening datasets")
     datasets = []
-    for fn in jobpath.glob("*.pkl"):
+    logger.debug(f"{job=}")
+    logger.debug(f"{job.jobpath=}")
+    for fn in Path(job.jobpath).glob("*.pkl"):
         with pickle.load(fn.open("rb")) as ds:
             datasets.append(ds)
     logger.debug("creating a DataTree from %d groups", len(datasets))
     dt = xr.DataTree.from_dict({ds.attrs["role"]: ds for ds in datasets})
+    logger.debug(f"{dt=}")
+    root_attrs = {
+        "tomato_version": importlib.metadata.version("tomato"),
+        "tomato_Job": job.model_dump_json(),
+    }
+    dt.attrs = root_attrs
+    outpath = job.snappath if snapshot else job.respath
     logger.debug("saving DataTree into '%s'", outpath)
     dt.to_netcdf(outpath, engine="h5netcdf")
+    logger.debug(f"{dt=}")
 
 
 def data_to_pickle(ds: xr.Dataset, path: Path, role: str):
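With this change, every file written by `merge_netcdfs` carries the tomato version and the serialised `Job` model as root-level attributes of the `xr.DataTree`. A minimal sketch of reading the metadata back, assuming a finished job produced a hypothetical "results.1.nc" and a recent xarray with the h5netcdf engine installed:

    import json
    import xarray as xr

    dt = xr.open_datatree("results.1.nc", engine="h5netcdf")
    print(dt.attrs["tomato_version"])         # version of tomato that wrote the file
    job = json.loads(dt.attrs["tomato_Job"])  # the Job model, round-tripped via JSON
    print(job["id"], job["status"])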
43 changes: 24 additions & 19 deletions src/tomato/daemon/job.py
@@ -27,9 +27,9 @@
 import psutil
 
 from tomato.daemon.io import merge_netcdfs, data_to_pickle
-from tomato.models import Pipeline, Daemon, Component, Device, Driver
+from tomato.models import Pipeline, Daemon, Component, Device, Driver, Job
 from dgbowl_schemas.tomato import to_payload
-from dgbowl_schemas.tomato.payload import Payload, Task
+from dgbowl_schemas.tomato.payload import Task
 
 logger = logging.getLogger(__name__)
 
@@ -111,7 +111,7 @@ def manage_running_pips(daemon: Daemon, req):
             proc = psutil.Process(pid=job.pid)
             kill_tomato_job(proc)
             logger.info(f"job {job.id} with pid {job.pid} was terminated successfully")
-            merge_netcdfs(Path(job.jobpath), Path(job.respath))
+            merge_netcdfs(job)
             reset = True
             params = dict(status="cd")
         # dead jobs marked as running (status == 'r') should be cleared
@@ -259,7 +259,6 @@ def manager(port: int, timeout: int = 500):
 def lazy_pirate(
     pyobj: Any, retries: int, timeout: int, address: str, context: zmq.Context
 ) -> Any:
-    logger.debug("Here")
     req = context.socket(zmq.REQ)
     req.connect(address)
     poller = zmq.Poller()
@@ -384,19 +383,27 @@ def tomato_job() -> None:
     respath = outpath / f"{prefix}.nc"
     snappath = outpath / f"snapshot.{jobid}.nc"
     params = dict(respath=str(respath), snappath=str(snappath), jobpath=str(jobpath))
-    lazy_pirate(pyobj=dict(cmd="job", id=jobid, params=params), **pkwargs)
+    ret = lazy_pirate(pyobj=dict(cmd="job", id=jobid, params=params), **pkwargs)
+    if ret.success is False:
+        logger.error("could not set job status for unknown reason")
+        return 1
+    job: Job = ret.data
 
     logger.info("handing off to 'job_main_loop'")
     logger.info("==============================")
-    job_main_loop(context, args.port, payload, pip, jobpath, snappath, logpath)
+    job_main_loop(context, args.port, job, pip, logpath)
     logger.info("==============================")
 
-    merge_netcdfs(jobpath, respath)
-    logger.info("job finished successfully")
+    job.completed_at = str(datetime.now(timezone.utc))
+    job.status = "c"
 
-    logger.info("job finished successfully, attempting to set status to 'c'")
-    params = dict(status="c", completed_at=str(datetime.now(timezone.utc)))
+    logger.info("writing final data to a NetCDF file")
+    merge_netcdfs(job)
+
+    logger.info("attempting to set job status to 'c'")
+    params = dict(status=job.status, completed_at=job.completed_at)
     ret = lazy_pirate(pyobj=dict(cmd="job", id=jobid, params=params), **pkwargs)
+    logger.debug(f"{ret=}")
     if ret.success is False:
         logger.error("could not set job status for unknown reason")
         return 1
@@ -438,7 +445,7 @@ def job_thread(
 
     kwargs = dict(address=component.address, channel=component.channel)
 
-    datapath = jobpath / f"{component.role}.pkl"
+    datapath = Path(jobpath) / f"{component.role}.pkl"
     logger.debug("distributing tasks:")
     for task in tasks:
         logger.debug(f"{task=}")
@@ -489,10 +496,8 @@ def job_thread(
 def job_main_loop(
     context: zmq.Context,
     port: int,
-    payload: Payload,
+    job: Job,
     pipname: str,
-    jobpath: Path,
-    snappath: Path,
     logpath: Path,
 ) -> None:
     """
@@ -507,7 +512,7 @@
 
     while True:
         req.send_pyobj(dict(cmd="status", sender=sender))
-        daemon = req.recv_pyobj().data
+        daemon: Daemon = req.recv_pyobj().data
         if all([drv.port is not None for drv in daemon.drvs.values()]):
             break
         else:
@@ -519,7 +524,7 @@
 
     # collate steps by role
     plan = {}
-    for step in payload.method:
+    for step in job.payload.method:
         if step.component_tag not in plan:
             plan[step.component_tag] = []
         plan[step.component_tag].append(step)
@@ -540,20 +545,20 @@
         logger.debug(" driver=%s", driver)
         threads[component.role] = Thread(
             target=job_thread,
-            args=(tasks, component, device, driver, jobpath, logpath),
+            args=(tasks, component, device, driver, job.jobpath, logpath),
             name="job-thread",
         )
         threads[component.role].start()
 
     # wait until threads join or we're killed
-    snapshot = payload.settings.snapshot
+    snapshot = job.payload.settings.snapshot
     t0 = time.perf_counter()
     while True:
         logger.debug("tick")
         tN = time.perf_counter()
         if snapshot is not None and tN - t0 > snapshot.frequency:
             logger.debug("creating snapshot")
-            merge_netcdfs(jobpath, snappath)
+            merge_netcdfs(job, snapshot=True)
             t0 += snapshot.frequency
         joined = [proc.is_alive() is False for proc in threads.values()]
         if all(joined):
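The upshot of these changes is that the `Job` model replaces the old `payload`/`jobpath`/`snappath` arguments throughout, and `merge_netcdfs` picks its output path from the model itself. The new calling convention, sketched with hypothetical paths:

    job.respath = "/data/results.1.nc"    # final results go here
    job.snappath = "/data/snapshot.1.nc"  # snapshots go here
    merge_netcdfs(job)                    # writes to job.respath
    merge_netcdfs(job, snapshot=True)     # writes to job.snappath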
7 changes: 4 additions & 3 deletions src/tomato/ketchup/__init__.py
@@ -25,7 +25,7 @@
 from dgbowl_schemas.tomato import to_payload
 
 from tomato.daemon.io import merge_netcdfs
-from tomato.models import Reply, Daemon
+from tomato.models import Reply, Daemon, Job
 
 log = logging.getLogger(__name__)
 
@@ -288,15 +288,16 @@ def snapshot(
         Success: snapshot for job [3] created successfully
     """
-    jobs = status.data.jobs
+    jobs: list[Job] = status.data.jobs
     for jobid in jobids:
         if jobid not in jobs:
             return Reply(success=False, msg=f"job {jobid} does not exist")
         if jobs[jobid].status in {"q", "qw"}:
             return Reply(success=False, msg=f"job {jobid} is still queued")
 
     for jobid in jobids:
-        merge_netcdfs(Path(jobs[jobid].jobpath), Path(f"snapshot.{jobid}.nc"))
+        jobs[jobid].snappath = Path(f"snapshot.{jobid}.nc")
+        merge_netcdfs(jobs[jobid], snapshot=True)
     if len(jobids) > 1:
         msg = f"snapshot for jobs {jobids} created successfully"
     else:
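Snapshots are still requested via the ketchup CLI; the only difference is that `merge_netcdfs` now receives the `Job` model with `snappath` set, so snapshot files also carry the job metadata. Usage as in the docstring above (jobid hypothetical):

    ketchup snapshot 3
    Success: snapshot for job [3] created successfully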
