Store Job and tomato version in NetCDF files. #109

Merged · 5 commits · Nov 20, 2024
18 changes: 14 additions & 4 deletions docs/source/usage.rst
@@ -174,13 +174,23 @@ Each *job* stores its data and logs in its own *job* folder, which is a subfolde
Note that a *pipeline* dashboard functionality is planned for a future version of ``tomato``.


Final job data
**************
Final job data and metadata
***************************
By default, all data in the *job* folder is processed to create a NetCDF file. The NetCDF files can be read using :func:`xarray.open_datatree`, returning a :class:`xarray.DataTree`.

In the root node of the :class:`~xarray.DataTree`, a copy of the full *payload* is included, serialised as a json :class:`str`. Additionally, execution-specific metadata, such as the *pipeline* ``name``, and *job* submission/execution/completion time are stored on the root node, too.
In the root node of the :class:`~xarray.DataTree`, the :obj:`attrs` dictionary contains all **tomato**-relevant metadata. This currently includes:

The child nodes of the :class:`~xarray.DataTree` contain the actual data from each *pipeline* *component*, unit-annotated using the CF Metadata Conventions. The node names correspond to the ``role`` that *component* fulfils in a *pipeline*.
- ``tomato_version``, which is the version of **tomato** used to create the NetCDF file,
- ``tomato_Job``, which is the *job* object serialised as a json :class:`str`, containing the full *payload* and sample information, as well as the *job* submission/execution/completion times.

The child nodes of the :class:`~xarray.DataTree` contain:

- the actual data from each *pipeline* *component*, unit-annotated using the CF Metadata Conventions. The node names correspond to the ``role`` that *component* fulfils in a *pipeline*.
- a ``tomato_Component`` entry in the :obj:`attrs` object, which is the *component* object serialised as a json :class:`str`, containing information about the *device* address and channel that define the *component*, the *driver* and *device* names, as well as the *component* capabilities.

.. note::

The ``tomato_Job`` and ``tomato_Component`` entries can be converted back to the source objects using :func:`tomato.models.Job.model_validate_json` and :func:`tomato.models.Component.model_validate_json`, respectively.

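As a sketch of how this metadata can be read back (the file name ``results.nc`` and the ``MFC`` role below are illustrative, not part of this PR):

```python
import xarray as xr
from tomato.models import Component, Job

# Open the final job data; the engine matches what tomato uses when writing.
dt = xr.open_datatree("results.nc", engine="h5netcdf")
print(dt.attrs["tomato_version"])  # version of tomato that wrote the file

# Round-trip the serialised metadata back into the source objects:
job = Job.model_validate_json(dt.attrs["tomato_Job"])
comp = Component.model_validate_json(dt["MFC"].attrs["tomato_Component"])
```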
4 changes: 3 additions & 1 deletion src/tomato/daemon/cmd.py
@@ -232,11 +232,13 @@ def job(msg: dict, daemon: Daemon) -> Reply:
daemon.jobs[jobid] = Job(id=jobid, **msg.get("params", {}))
logger.info("received job %d", jobid)
daemon.nextjob += 1
ret = daemon.jobs[jobid]
else:
for k, v in msg.get("params", {}).items():
logger.debug("setting job parameter %s.%s to %s", jobid, k, v)
setattr(daemon.jobs[jobid], k, v)
cjob = daemon.jobs[jobid]
ret = cjob
if cjob.status in {"c"}:
daemon.jobs[jobid] = CompletedJob(
id=cjob.id,
@@ -246,7 +248,7 @@
jobpath=cjob.jobpath,
respath=cjob.respath,
)
return Reply(success=True, msg="job updated", data=daemon.jobs[jobid])
return Reply(success=True, msg="job updated", data=ret)


def driver(msg: dict, daemon: Daemon) -> Reply:
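The ``ret`` variable introduced above means the caller now receives the full ``Job`` (including its *payload*) even when the registry entry is immediately swapped for a slimmer ``CompletedJob``. A minimal sketch of this capture-before-replace pattern, with plain dicts standing in for the pydantic models (hypothetical, for illustration only):

```python
# Hypothetical stand-in for the daemon's Job registry.
jobs = {1: {"id": 1, "status": "r", "payload": {"method": []}}}

def update_job(jobid: int, **params) -> dict:
    jobs[jobid].update(params)
    ret = jobs[jobid]                 # keep a reference to the full record
    if ret["status"] == "c":
        # the registry keeps only a slim completed record...
        jobs[jobid] = {"id": ret["id"], "status": "c"}
    return ret                        # ...but the caller still gets the full one

full = update_job(1, status="c")
assert "payload" in full              # without `ret`, this information was lost
```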
21 changes: 16 additions & 5 deletions src/tomato/daemon/io.py
@@ -9,8 +9,9 @@
import pickle
import logging
import xarray as xr
import importlib.metadata
from pathlib import Path
from tomato.models import Daemon
from tomato.models import Daemon, Job

logger = logging.getLogger(__name__)

@@ -42,22 +43,32 @@ def load(daemon: Daemon):
daemon.status = "running"


def merge_netcdfs(jobpath: Path, outpath: Path):
def merge_netcdfs(job: Job, snapshot=False):
"""
Merges the individual pickled :class:`xr.Datasets` of each Component found in
`jobpath` into a single :class:`xr.DataTree`, which is then stored in the NetCDF file,
Merges the individual pickled :class:`xr.Datasets` of each Component found in :obj:`job.jobpath`
into a single :class:`xr.DataTree`, which is then stored in the NetCDF file,
using the Component `role` as the group label.
"""
logger = logging.getLogger(f"{__name__}.merge_netcdf")
logger.debug("opening datasets")
datasets = []
for fn in jobpath.glob("*.pkl"):
logger.debug(f"{job=}")
logger.debug(f"{job.jobpath=}")
for fn in Path(job.jobpath).glob("*.pkl"):
with pickle.load(fn.open("rb")) as ds:
datasets.append(ds)
logger.debug("creating a DataTree from %d groups", len(datasets))
dt = xr.DataTree.from_dict({ds.attrs["role"]: ds for ds in datasets})
logger.debug(f"{dt=}")
root_attrs = {
"tomato_version": importlib.metadata.version("tomato"),
"tomato_Job": job.model_dump_json(),
}
dt.attrs = root_attrs
outpath = job.snappath if snapshot else job.respath
logger.debug("saving DataTree into '%s'", outpath)
dt.to_netcdf(outpath, engine="h5netcdf")
logger.debug(f"{dt=}")


def data_to_pickle(ds: xr.Dataset, path: Path, role: str):
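A self-contained sketch of the merge pattern now used by ``merge_netcdfs`` (datasets, roles, and attribute values below are made up for illustration):

```python
import json
import xarray as xr

# Two per-component datasets, as would be unpickled from the job folder:
ds_a = xr.Dataset({"T": ("uts", [298.0, 299.0])}, attrs={"role": "heater"})
ds_b = xr.Dataset({"I": ("uts", [0.1, 0.2])}, attrs={"role": "potentiostat"})

# Group the datasets by role, then stamp tomato's metadata on the root node:
dt = xr.DataTree.from_dict({ds.attrs["role"]: ds for ds in (ds_a, ds_b)})
dt.attrs = {
    "tomato_version": "2.1",              # illustrative value
    "tomato_Job": json.dumps({"id": 1}),  # stand-in for job.model_dump_json()
}
dt.to_netcdf("results.nc", engine="h5netcdf")
```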
47 changes: 27 additions & 20 deletions src/tomato/daemon/job.py
@@ -27,9 +27,9 @@
import psutil

from tomato.daemon.io import merge_netcdfs, data_to_pickle
from tomato.models import Pipeline, Daemon, Component, Device, Driver
from tomato.models import Pipeline, Daemon, Component, Device, Driver, Job
from dgbowl_schemas.tomato import to_payload
from dgbowl_schemas.tomato.payload import Payload, Task
from dgbowl_schemas.tomato.payload import Task

logger = logging.getLogger(__name__)

@@ -111,7 +111,7 @@ def manage_running_pips(daemon: Daemon, req):
proc = psutil.Process(pid=job.pid)
kill_tomato_job(proc)
logger.info(f"job {job.id} with pid {job.pid} was terminated successfully")
merge_netcdfs(Path(job.jobpath), Path(job.respath))
merge_netcdfs(job)
reset = True
params = dict(status="cd")
# dead jobs marked as running (status == 'r') should be cleared
@@ -259,7 +259,6 @@ def manager(port: int, timeout: int = 500):
def lazy_pirate(
pyobj: Any, retries: int, timeout: int, address: str, context: zmq.Context
) -> Any:
logger.debug("Here")
req = context.socket(zmq.REQ)
req.connect(address)
poller = zmq.Poller()
@@ -384,19 +383,27 @@ def tomato_job() -> None:
respath = outpath / f"{prefix}.nc"
snappath = outpath / f"snapshot.{jobid}.nc"
params = dict(respath=str(respath), snappath=str(snappath), jobpath=str(jobpath))
lazy_pirate(pyobj=dict(cmd="job", id=jobid, params=params), **pkwargs)
ret = lazy_pirate(pyobj=dict(cmd="job", id=jobid, params=params), **pkwargs)
if ret.success is False:
logger.error("could not set job status for unknown reason")
return 1
job: Job = ret.data

logger.info("handing off to 'job_main_loop'")
logger.info("==============================")
job_main_loop(context, args.port, payload, pip, jobpath, snappath, logpath)
job_main_loop(context, args.port, job, pip, logpath)
logger.info("==============================")

merge_netcdfs(jobpath, respath)
logger.info("job finished successfully")
job.completed_at = str(datetime.now(timezone.utc))
job.status = "c"

logger.info("job finished successfully, attempting to set status to 'c'")
params = dict(status="c", completed_at=str(datetime.now(timezone.utc)))
logger.info("writing final data to a NetCDF file")
merge_netcdfs(job)

logger.info("attempting to set job status to 'c'")
params = dict(status=job.status, completed_at=job.completed_at)
ret = lazy_pirate(pyobj=dict(cmd="job", id=jobid, params=params), **pkwargs)
logger.debug(f"{ret=}")
if ret.success is False:
logger.error("could not set job status for unknown reason")
return 1
Expand Down Expand Up @@ -438,7 +445,7 @@ def job_thread(

kwargs = dict(address=component.address, channel=component.channel)

datapath = jobpath / f"{component.role}.pkl"
datapath = Path(jobpath) / f"{component.role}.pkl"
logger.debug("distributing tasks:")
for task in tasks:
logger.debug(f"{task=}")
@@ -463,7 +470,9 @@
ret = req.recv_pyobj()
if ret.success:
logger.debug("pickling received data")
data_to_pickle(ret.data, datapath, role=component.role)
ds = ret.data
ds.attrs["tomato_Component"] = component.model_dump_json()
data_to_pickle(ds, datapath, role=component.role)
t0 += device.pollrate

logger.debug("polling component '%s' for task completion", component.role)
@@ -489,10 +498,8 @@
def job_main_loop(
context: zmq.Context,
port: int,
payload: Payload,
job: Job,
pipname: str,
jobpath: Path,
snappath: Path,
logpath: Path,
) -> None:
"""
@@ -507,7 +514,7 @@

while True:
req.send_pyobj(dict(cmd="status", sender=sender))
daemon = req.recv_pyobj().data
daemon: Daemon = req.recv_pyobj().data
if all([drv.port is not None for drv in daemon.drvs.values()]):
break
else:
@@ -519,7 +526,7 @@

# collate steps by role
plan = {}
for step in payload.method:
for step in job.payload.method:
if step.component_tag not in plan:
plan[step.component_tag] = []
plan[step.component_tag].append(step)
@@ -540,20 +547,20 @@
logger.debug(" driver=%s", driver)
threads[component.role] = Thread(
target=job_thread,
args=(tasks, component, device, driver, jobpath, logpath),
args=(tasks, component, device, driver, job.jobpath, logpath),
name="job-thread",
)
threads[component.role].start()

# wait until threads join or we're killed
snapshot = payload.settings.snapshot
snapshot = job.payload.settings.snapshot
t0 = time.perf_counter()
while True:
logger.debug("tick")
tN = time.perf_counter()
if snapshot is not None and tN - t0 > snapshot.frequency:
logger.debug("creating snapshot")
merge_netcdfs(jobpath, snappath)
merge_netcdfs(job, snapshot=True)
t0 += snapshot.frequency
joined = [proc.is_alive() is False for proc in threads.values()]
if all(joined):
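Taken together, the hunks above reduce the main-loop snapshotting to the following pattern (a paraphrased sketch; ``snapshot.frequency`` is assumed to be in seconds, and the sleep interval is illustrative):

```python
# Periodic snapshotting in job_main_loop, consolidated:
snapshot = job.payload.settings.snapshot
t0 = time.perf_counter()
while True:
    tN = time.perf_counter()
    if snapshot is not None and tN - t0 > snapshot.frequency:
        merge_netcdfs(job, snapshot=True)  # output path comes from job.snappath
        t0 += snapshot.frequency
    if all(not t.is_alive() for t in threads.values()):
        break
    time.sleep(1)
```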
48 changes: 26 additions & 22 deletions src/tomato/driverinterface_1_0/__init__.py
@@ -46,7 +46,7 @@ class Attr(BaseModel):

class ModelInterface(metaclass=ABCMeta):
"""
An abstract base class specifying the a driver interface.
An abstract base class specifying the driver interface.

Individual driver modules should expose a :class:`DriverInterface` which inherits
from this abstract class. Only the methods of this class should be used to interact
@@ -58,9 +58,13 @@ class ModelInterface(metaclass=ABCMeta):
class DeviceManager(metaclass=ABCMeta):
"""
An abstract base class specifying a manager for an individual component.

This class should handle determining attributes and capabilities of the component,
the reading/writing of those attributes, processing of tasks, and caching and
returning of task data.
"""

driver: super
driver: "ModelInterface"
"""The parent :class:`DriverInterface` instance."""

data: dict[str, list]
@@ -70,7 +74,7 @@ class DeviceManager(metaclass=ABCMeta):
"""Lock object for thread-safe data manipulation."""

key: tuple
"""The key in :obj:`driver.devmap` referring to this object."""
"""The key in :obj:`self.driver.devmap` referring to this object."""

thread: Thread
"""The worker :class:`Thread`."""
@@ -203,8 +207,8 @@ def reset(self, **kwargs) -> None:

def CreateDeviceManager(self, key, **kwargs):
"""
A factory function which is used to pass this :class:`ModelInterface` to the new
:class:`DeviceManager` instance.
A factory function which is used to pass this instance of the :class:`ModelInterface`
to the new :class:`DeviceManager` instance.
"""
return self.DeviceManager(self, key, **kwargs)

@@ -272,20 +276,6 @@ def dev_reset(self, key: tuple, **kwargs: dict) -> Reply:
msg=f"component {key!r} reset successfully",
)

@in_devmap
def attrs(self, key: tuple, **kwargs: dict) -> Reply:
"""
Query available :class:`Attrs` on the specified device component.

Pass-through to the :func:`DeviceManager.attrs` function.
"""
ret = self.devmap[key].attrs(**kwargs)
return Reply(
success=True,
msg=f"attrs of component {key!r} are: {ret}",
data=ret,
)

@in_devmap
def dev_set_attr(self, attr: str, val: Any, key: tuple, **kwargs: dict) -> Reply:
"""
@@ -307,7 +297,7 @@ def dev_get_attr(self, attr: str, key: tuple, **kwargs: dict) -> Reply:
Get value of the :class:`Attr` from the specified device component.

Pass-through to the :func:`DeviceManager.get_attr` function. Units are not
returned; those can be queried for all :class:`Attrs` using :func:`attrs`.
returned; those can be queried for all :class:`Attrs` using :func:`self.attrs`.

"""
ret = self.devmap[key].get_attr(attr=attr, **kwargs)
@@ -322,8 +312,8 @@ def dev_status(self, key: tuple, **kwargs: dict) -> Reply:
"""
Get the status report from the specified device component.

Iterates over all :class:`Attrs` on the component that have `status=True` and
returns their values in a :class:`dict`.
Iterates over all :class:`Attrs` on the component that have ``status=True`` and
returns their values in the :obj:`Reply.data` as a :class:`dict`.
"""
ret = {}
for k, attr in self.devmap[key].attrs(key=key, **kwargs).items():
@@ -474,3 +464,17 @@ def capabilities(self, key: tuple, **kwargs) -> Reply:
msg=f"capabilities supported by component {key!r} are: {ret}",
data=ret,
)

@in_devmap
def attrs(self, key: tuple, **kwargs: dict) -> Reply:
"""
Query available :class:`Attrs` on the specified device component.

Pass-through to the :func:`DeviceManager.attrs` function.
"""
ret = self.devmap[key].attrs(**kwargs)
return Reply(
success=True,
msg=f"attrs of component {key!r} are: {ret}",
data=ret,
)
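
A hedged sketch of how these pass-through methods can be called (``iface`` is a concrete ``DriverInterface``; the key tuple and attribute names are hypothetical, and only signatures visible in this diff are assumed):

```python
key = ("COM4", "1")                 # (address, channel) of a component

reply = iface.attrs(key=key)        # Reply(success=..., msg=..., data=...)
if reply.success:
    for name in reply.data:         # data maps attr names to Attr models
        print(name, iface.dev_get_attr(attr=name, key=key).data)

status = iface.dev_status(key=key)  # values of all Attrs with status=True
print(status.data)
```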
7 changes: 4 additions & 3 deletions src/tomato/ketchup/__init__.py
@@ -25,7 +25,7 @@
from dgbowl_schemas.tomato import to_payload

from tomato.daemon.io import merge_netcdfs
from tomato.models import Reply, Daemon
from tomato.models import Reply, Daemon, Job

log = logging.getLogger(__name__)

@@ -288,15 +288,16 @@ def snapshot(
Success: snapshot for job [3] created successfully

"""
jobs = status.data.jobs
jobs: dict[int, Job] = status.data.jobs
for jobid in jobids:
if jobid not in jobs:
return Reply(success=False, msg=f"job {jobid} does not exist")
if jobs[jobid].status in {"q", "qw"}:
return Reply(success=False, msg=f"job {jobid} is still queued")

for jobid in jobids:
merge_netcdfs(Path(jobs[jobid].jobpath), Path(f"snapshot.{jobid}.nc"))
jobs[jobid].snappath = Path(f"snapshot.{jobid}.nc")
merge_netcdfs(jobs[jobid], snapshot=True)
if len(jobids) > 1:
msg = f"snapshot for jobs {jobids} created successfully"
else:
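Since ``snappath`` is now set on the job before the call, ``merge_netcdfs(jobs[jobid], snapshot=True)`` writes the snapshot to ``snapshot.{jobid}.nc``. A sketch of inspecting such a file (job id ``3`` follows the docstring example above):

```python
import xarray as xr

dt = xr.open_datatree("snapshot.3.nc", engine="h5netcdf")
print(dt.attrs["tomato_version"], sorted(dt.children))
```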