Skip to content

Commit

Permalink
Refactor tomato reload (#96)
Browse files Browse the repository at this point in the history
* Implement first part of reload.

* Fix failing stop test.

* more stop job fixes

* fix one more test.

* More tests.

* That's ruff man
  • Loading branch information
PeterKraus authored Jul 30, 2024
1 parent 414eadd commit f0b44fa
Show file tree
Hide file tree
Showing 14 changed files with 482 additions and 209 deletions.
2 changes: 1 addition & 1 deletion docs/source/usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ to check the status of and to cancel *jobs* in the queue.
*Jobs* submitted to the queue will remain in the queue until a *pipeline* meets all
of the following criteria:

- A *pipeline* where all of the ``techniques`` specified in the *payload* are matched
- A *pipeline* where all of the ``tasks`` specified in the *payload* are matched
by its ``capabilities`` must exist. Once the :mod:`tomato.daemon` finds such a
*pipeline*, the status of the *job* will change to ``qw``.
- The matching *pipeline* must contain a *sample* with a ``samplename`` that matches
Expand Down
120 changes: 115 additions & 5 deletions src/tomato/daemon/cmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,25 +58,135 @@ def stop(msg: dict, daemon: Daemon) -> Reply:
io.store(daemon)
if any([pip.jobid is not None for pip in daemon.pips.values()]):
logger.error("cannot stop tomato-daemon as jobs are running")
return Reply(success=False, msg=daemon.status, data=daemon)
return Reply(success=False, msg="jobs are running", data=daemon.jobs)
else:
daemon.status = "stop"
logger.critical("stopping tomato-daemon")
return Reply(success=True, msg=daemon.status)
return Reply(success=True)


def setup(msg: dict, daemon: Daemon) -> Reply:
    """
    Set up the daemon on first use, or reload its configuration afterwards.

    If ``daemon.status`` is ``"bootstrap"``, every one of the ``drvs``, ``devs``,
    ``pips`` and ``cmps`` entries present in ``msg`` is copied onto ``daemon`` and
    the status is switched to ``"running"``.

    Otherwise the call is treated as a reload. Before anything is changed, every
    pipeline with a ``jobid`` is checked: the pipeline itself, its components, and
    the devices and drivers of those components must all still be present in
    ``msg`` with the compared attributes unchanged. If any check fails, nothing is
    modified and a failed reply is returned. If all checks pass, the driver,
    pipeline, component and device mappings on ``daemon`` are synchronised with
    ``msg`` via :func:`_api_reload`.

    Parameters
    ----------
    msg
        The request. In the reload branch the keys ``"drvs"``, ``"devs"``,
        ``"pips"`` and ``"cmps"`` are all accessed by index and must be present.
    daemon
        The daemon state, modified in place.

    Returns
    -------
    Reply
        ``success=False`` with an explanatory ``msg`` and the offending object in
        ``data`` when the reload is refused; ``success=True`` with the daemon in
        ``data`` otherwise.
    """
    logger = logging.getLogger(f"{__name__}.setup")
    logger.debug("%s", msg)
    if daemon.status == "bootstrap":
        for key in ["drvs", "devs", "pips", "cmps"]:
            if key in msg:
                setattr(daemon, key, msg[key])
        logger.info("setup successful with pipelines: '%s'", daemon.pips.keys())
        daemon.status = "running"
    else:
        # First, check that we're not touching anything associated with a running job
        check_components = set()
        check_devices = set()
        check_drivers = set()
        for dpip in daemon.pips.values():
            if dpip.jobid is None:
                continue
            if dpip.name not in msg["pips"]:
                return Reply(
                    success=False,
                    msg="reload would delete a running pipeline",
                    data=dpip,
                )
            pip = msg["pips"][dpip.name]
            if pip.components != dpip.components:
                return Reply(
                    success=False,
                    msg="reload would modify components of a running pipeline",
                    data=dpip,
                )
            check_components.update(dpip.components)

        for cname in check_components:
            dcomp = daemon.cmps[cname]
            if cname not in msg["cmps"]:
                return Reply(
                    success=False,
                    msg="reload would delete a component of a running pipeline",
                    data=dcomp,
                )
            comp = msg["cmps"][cname]
            if (
                dcomp.name != comp.name
                or dcomp.driver != comp.driver
                or dcomp.device != comp.device
                or dcomp.address != comp.address
                or dcomp.channel != comp.channel
                or dcomp.role != comp.role
            ):
                return Reply(
                    success=False,
                    msg="reload would modify a component of a running pipeline",
                    data=dcomp,
                )
            check_devices.add(dcomp.device)
            check_drivers.add(dcomp.driver)

        for dname in check_devices:
            ddev = daemon.devs[dname]
            if dname not in msg["devs"]:
                return Reply(
                    success=False,
                    msg="reload would delete a device of a component in a running pipeline",
                    data=ddev,
                )
            dev = msg["devs"][dname]
            # Channels may be added to a device in use, but none may be removed.
            if (
                ddev.name != dev.name
                or ddev.driver != dev.driver
                or ddev.address != dev.address
                or ddev.pollrate != dev.pollrate
                or any(ch not in dev.channels for ch in ddev.channels)
            ):
                return Reply(
                    success=False,
                    msg="reload would modify a device of a component in a running pipeline",
                    data=ddev,
                )

        for dname in check_drivers:
            ddrv = daemon.drvs[dname]
            if dname not in msg["drvs"]:
                # Report the driver itself, not the device left over from the
                # previous loop.
                return Reply(
                    success=False,
                    msg="reload would delete a driver of a device in a running pipeline",
                    data=ddrv,
                )
            drv = msg["drvs"][dname]
            if ddrv.name != drv.name or ddrv.settings != drv.settings:
                return Reply(
                    success=False,
                    msg="reload would modify a driver of a device in a running pipeline",
                    data=ddrv,
                )

        # All checks passed: synchronise the daemon's mappings with the request.
        _api_reload(msg["drvs"], daemon.drvs, "driver", ["settings"])

        _api_reload(msg["pips"], daemon.pips, "pipeline", ["components"])

        attrlist = ["driver", "device", "address", "channel", "role"]
        _api_reload(msg["cmps"], daemon.cmps, "component", attrlist)

        _api_reload(msg["devs"], daemon.devs, "device", ["channels", "pollrate"])

        logger.info("reload successful with pipelines: '%s'", daemon.pips.keys())
    return Reply(success=True, data=daemon)


def _api_reload(mdict: dict, ddict: dict, objname: str, attrlist: list[str]):
    """
    Synchronise the mapping ``ddict`` in place with the requested ``mdict``.

    Objects from ``mdict`` whose ``name`` is not a key of ``ddict`` are inserted
    under that name. For objects already present, each attribute listed in
    ``attrlist`` that differs is overwritten with the value from ``mdict``.
    Finally, objects in ``ddict`` whose ``name`` is not a key of ``mdict`` are
    deleted. ``objname`` is only used to label the log messages.
    """
    for incoming in mdict.values():
        if incoming.name in ddict:
            current = ddict[incoming.name]
            for attr in attrlist:
                old_value = getattr(current, attr)
                new_value = getattr(incoming, attr)
                if old_value != new_value:
                    logger.debug("%s '%s.%s' updated", objname, current.name, attr)
                    setattr(current, attr, new_value)
        else:
            logger.debug("adding new %s '%s'", objname, incoming.name)
            ddict[incoming.name] = incoming

    # Collect the stale names first so that the mapping is not resized while it
    # is being iterated.
    stale_names = [
        existing.name for existing in ddict.values() if existing.name not in mdict
    ]
    for stale in stale_names:
        logger.warning("removing unused %s '%s'", objname, stale)
        del ddict[stale]


def pipeline(msg: dict, daemon: Daemon) -> Reply:
Expand Down
97 changes: 58 additions & 39 deletions src/tomato/daemon/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import argparse
from importlib import metadata
from datetime import datetime, timezone
from threading import currentThread
from threading import current_thread

import zmq
import psutil
Expand All @@ -25,6 +25,28 @@
logger = logging.getLogger(__name__)


def tomato_driver_bootstrap(
    req: zmq.Socket, logger: logging.Logger, interface: ModelInterface, driver: str
):
    """
    Configure ``interface`` from the daemon state and register its components.

    A ``status`` request (with data) is sent over ``req`` and the ``data`` of the
    reply is used as the daemon state. The settings of the entry for ``driver``
    in ``daemon.drvs`` are assigned to ``interface.settings``. Then, for every
    component in ``daemon.cmps`` whose ``driver`` matches, the component is
    registered on the interface using its address and channel, and the ``data``
    of the registration result is reported back over ``req`` as the component's
    ``capabilities`` using a ``component`` command.
    """
    logger.debug("getting daemon status")
    req.send_pyobj({"cmd": "status", "with_data": True})
    status_reply = req.recv_pyobj()
    daemon = status_reply.data
    interface.settings = daemon.drvs[driver].settings

    logger.info("registering components for driver '%s'", driver)
    for comp in daemon.cmps.values():
        if comp.driver != driver:
            continue
        logger.info("registering component '%s'", comp.name)
        ret = interface.dev_register(address=comp.address, channel=comp.channel)
        logger.debug(f"iface {ret=}")
        request = {
            "cmd": "component",
            "params": {"name": comp.name, "capabilities": ret.data},
        }
        req.send_pyobj(request)
        ret = req.recv_pyobj()
        logger.debug(f"daemon {ret=}")
    logger.info("driver '%s' bootstrapped successfully", driver)


def tomato_driver() -> None:
"""
The function called when `tomato-driver` is executed.
Expand Down Expand Up @@ -89,33 +111,14 @@ def tomato_driver() -> None:
elif psutil.POSIX:
pid = os.getpid()

logger.debug("getting daemon status")
req.send_pyobj(
dict(cmd="status", with_data=True, sender=f"{__name__}.tomato_driver_bootstrap")
)
daemon = req.recv_pyobj().data
logger.debug(f"{daemon=}")

logger.info(f"attempting to spawn driver {args.driver!r}")
logger.info("attempting to create Interface for driver '%s'", args.driver)
Interface = driver_to_interface(args.driver)
if Interface is None:
logger.critical(f"library of driver {args.driver!r} not found")
logger.critical("class DriverInterface driver '%s' not found", args.driver)
return
drv = daemon.drvs[args.driver]
interface: ModelInterface = Interface(settings=drv.settings)

logger.info("registering components for driver '%s'", args.driver)
for comp in daemon.cmps.values():
if comp.driver == args.driver:
logger.info("registering component '%s'", comp.name)
ret = interface.dev_register(address=comp.address, channel=comp.channel)
logger.debug(f"iface {ret=}")
params = dict(name=comp.name, capabilities=ret.data)
req.send_pyobj(dict(cmd="component", params=params))
ret = req.recv_pyobj()
logger.debug(f"daemon {ret=}")

logger.info("driver '%s' bootstrapped successfully", args.driver)
interface: ModelInterface = Interface()
tomato_driver_bootstrap(req, logger, interface, args.driver)

params = dict(
name=args.driver,
Expand Down Expand Up @@ -152,6 +155,13 @@ def tomato_driver() -> None:
msg=f"status of driver {params['name']!r} is {status!r}",
data=dict(**params, status=status),
)
elif msg["cmd"] == "register":
tomato_driver_bootstrap(req, logger, interface, args.driver)
ret = Reply(
success=True,
msg="components re-registered successfully",
data=interface.devmap.keys(),
)
elif msg["cmd"] == "stop":
status = "stop"
ret = Reply(
Expand Down Expand Up @@ -220,10 +230,10 @@ def manager(port: int, timeout: int = 1000):
This manager ensures individual driver processes are (re-)spawned and instructed to
quit as necessary.
"""

sender = f"{__name__}.manager"
context = zmq.Context()
logger = logging.getLogger(f"{__name__}.manager")
thread = currentThread()
logger = logging.getLogger(sender)
thread = current_thread()
logger.info("launched successfully")
req = context.socket(zmq.REQ)
req.connect(f"tcp://127.0.0.1:{port}")
Expand All @@ -232,10 +242,10 @@ def manager(port: int, timeout: int = 1000):
to = timeout

while getattr(thread, "do_run"):
req.send_pyobj(dict(cmd="status", with_data=True, sender=f"{__name__}.manager"))
req.send_pyobj(dict(cmd="status", with_data=True, sender=sender))
events = dict(poller.poll(to))
if req not in events:
logger.warning(f"could not contact tomato-daemon in {to} ms")
logger.warning("could not contact tomato-daemon in %d ms", to)
to = to * 2
continue
elif to > timeout:
Expand All @@ -250,24 +260,33 @@ def manager(port: int, timeout: int = 1000):
else:
drv = daemon.drvs[driver]
if drv.pid is not None and not psutil.pid_exists(drv.pid):
logger.warning(f"respawning crashed driver {driver!r}")
logger.warning("respawning crashed driver '%s'", driver)
spawn_tomato_driver(daemon.port, driver, req, daemon.verbosity)
action_counter += 1
elif drv.pid is None and drv.spawned_at is None:
logger.debug(f"spawning driver {driver!r}")
logger.debug("spawning driver '%s'", driver)
spawn_tomato_driver(daemon.port, driver, req, daemon.verbosity)
action_counter += 1
elif drv.pid is None:
tspawn = datetime.fromisoformat(drv.spawned_at)
if (datetime.now(timezone.utc) - tspawn).seconds > 10:
logger.warning(f"respawning late driver {driver!r}")
spawn_tomato_driver(daemon.port, driver, req, daemon.verbosity)
action_counter += 1
logger.debug("tick")
if action_counter == 0:
contact_drivers = set()
for comp in daemon.cmps.values():
if comp.capabilities is None:
contact_drivers.add(comp.driver)
for driver in contact_drivers:
drv = daemon.drvs[driver]
if drv.port is None:
continue
logger.debug("contacting driver '%s' to re-register components", driver)
dreq = context.socket(zmq.REQ)
dreq.connect(f"tcp://127.0.0.1:{drv.port}")
dreq.send_pyobj(dict(cmd="register", params=None, sender=sender))
ret = dreq.recv_pyobj()
logger.debug(f"{ret=}")
dreq.close()
time.sleep(1 if action_counter > 0 else 0.1)

logger.info("instructed to quit")
req.send_pyobj(dict(cmd="status", with_data=True, sender=f"{__name__}.manager"))
req.send_pyobj(dict(cmd="status", with_data=True, sender=sender))
daemon = req.recv_pyobj().data
for driver in daemon.drvs.values():
logger.debug("stopping driver '%s' on port %d", driver.name, driver.port)
Expand Down
2 changes: 1 addition & 1 deletion src/tomato/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,5 +79,5 @@ class Daemon(BaseModel, arbitrary_types_allowed=True):

class Reply(BaseModel):
    """Reply object passed back to the sender of a command."""

    # Whether the command succeeded.
    success: bool
    # Optional human-readable message; may be omitted.
    msg: Optional[str] = None
    # Optional payload accompanying the reply.
    data: Optional[Any] = None
Loading

0 comments on commit f0b44fa

Please sign in to comment.