From 1c6a6083b5dcb412bdcb899d94c8ec15304c2999 Mon Sep 17 00:00:00 2001 From: Peter Kraus Date: Fri, 31 May 2024 14:10:49 +0200 Subject: [PATCH 01/17] Implement DriverInterface. --- src/tomato/daemon/driver.py | 22 +-- .../drivers/example_counter/__init__.py | 95 ++++++------- src/tomato/models.py | 128 ++++++++++++++++++ 3 files changed, 188 insertions(+), 57 deletions(-) diff --git a/src/tomato/daemon/driver.py b/src/tomato/daemon/driver.py index a2377faa..84b9fc9c 100644 --- a/src/tomato/daemon/driver.py +++ b/src/tomato/daemon/driver.py @@ -101,14 +101,14 @@ def tomato_driver() -> None: return kwargs = dict(settings=daemon.drvs[args.driver].settings) - driver = getattr(tomato.drivers, args.driver).Driver(**kwargs) + interface = getattr(tomato.drivers, args.driver).DriverInterface(**kwargs) logger.info(f"registering devices in driver {args.driver!r}") for dev in daemon.devs.values(): if dev.driver == args.driver: for channel in dev.channels: - driver.dev_register(address=dev.address, channel=channel) - logger.debug(f"{driver.devmap=}") + interface.dev_register(address=dev.address, channel=channel) + logger.debug(f"{interface.devmap=}") logger.info(f"driver {args.driver!r} bootstrapped successfully") @@ -117,7 +117,7 @@ def tomato_driver() -> None: port=port, pid=pid, connected_at=str(datetime.now(timezone.utc)), - settings=driver.settings, + settings=interface.settings, ) req.send_pyobj( dict(cmd="driver", params=params, sender=f"{__name__}.tomato_driver") @@ -155,26 +155,26 @@ def tomato_driver() -> None: data=dict(status=status, driver=args.driver), ) elif msg["cmd"] == "settings": - driver.settings = msg["params"] - params["settings"] = driver.settings + interface.settings = msg["params"] + params["settings"] = interface.settings ret = Reply( success=True, msg="settings received", data=msg.get("params"), ) elif msg["cmd"] == "dev_register": - driver.dev_register(**msg["params"]) + interface.dev_register(**msg["params"]) ret = Reply( success=True, msg="device registered", data=msg.get("params"), ) elif msg["cmd"] == "task_status": - ret = driver.task_status(**msg["params"]) + ret = interface.task_status(**msg["params"]) elif msg["cmd"] == "task_start": - ret = driver.task_start(**msg["params"]) + ret = interface.task_start(**msg["params"]) elif msg["cmd"] == "task_data": - ret = driver.task_data(**msg["params"]) + ret = interface.task_data(**msg["params"]) logger.debug(f"{ret=}") rep.send_pyobj(ret) if status == "stop": @@ -182,7 +182,7 @@ def tomato_driver() -> None: logger.info(f"driver {args.driver!r} is beginning teardown") - driver.teardown() + interface.teardown() logger.critical(f"driver {args.driver!r} is quitting") diff --git a/src/tomato/drivers/example_counter/__init__.py b/src/tomato/drivers/example_counter/__init__.py index 31e59611..6f4d0d94 100644 --- a/src/tomato/drivers/example_counter/__init__.py +++ b/src/tomato/drivers/example_counter/__init__.py @@ -5,7 +5,7 @@ from functools import wraps from tomato.drivers.example_counter.counter import Counter -from tomato.models import Reply +from tomato.models import Reply, DriverInterface from xarray import Dataset logger = logging.getLogger(__name__) @@ -24,8 +24,8 @@ def wrapper(self, **kwargs): return wrapper -class Driver: - class Device: +class DriverInterface(DriverInterface): + class DeviceInterface: dev: Counter conn: Connection proc: Process @@ -35,49 +35,48 @@ def __init__(self): self.conn, conn = Pipe() self.proc = Process(target=self.dev.run_counter, args=(conn,)) - devmap: dict[tuple, Device] + devmap: dict[tuple, DeviceInterface] settings: dict - attrs: dict = dict( - delay=dict(type=float, rw=True), - time=dict(type=float, rw=True), - started=dict(type=bool, rw=True), - val=dict(type=int, rw=False), - ) - - tasks: dict = dict( - count=dict( - time=dict(type=float), - delay=dict(type=float), - ), - random=dict( - time=dict(type=float), - delay=dict(type=float), - min=dict(type=float), - max=dict(type=float), - ), - ) - - def __init__(self, settings=None): - self.devmap = {} - self.settings = settings if settings is not None else {} + def attrs(self, **kwargs) -> dict: + return dict( + delay=dict(type=float, rw=True), + time=dict(type=float, rw=True), + started=dict(type=bool, rw=True), + val=dict(type=int, rw=False), + ) + + def tasks(self, **kwargs) -> dict: + return dict( + count=dict( + time=dict(type=float), + delay=dict(type=float), + ), + random=dict( + time=dict(type=float), + delay=dict(type=float), + min=dict(type=float), + max=dict(type=float), + ), + ) def dev_register(self, address: str, channel: int, **kwargs): key = (address, channel) - self.devmap[key] = self.Device() + self.devmap[key] = self.DeviceInterface() self.devmap[key].proc.start() @in_devmap - def dev_attr_set(self, attr: str, val: Any, address: str, channel: int, **kwargs): + def dev_set_attr(self, attr: str, val: Any, address: str, channel: int, **kwargs): key = (address, channel) - if attr in self.attrs: - if self.attrs[attr]["rw"] and isinstance(val, self.attrs[attr]["type"]): + if attr in self.attrs(): + params = self.attrs()[attr] + if params["rw"] and isinstance(val, params["type"]): self.devmap[key].conn.send(("set", attr, val)) @in_devmap - def dev_attr_get(self, attr: str, address: str, channel: int, **kwargs): + def dev_get_attr(self, attr: str, address: str, channel: int, **kwargs): key = (address, channel) - if attr in self.attrs: + if attr in self.attrs(): self.devmap[key].conn.send(("get", attr, None)) return self.devmap[key].conn.recv() @@ -87,24 +86,16 @@ def dev_status(self, address: str, channel: int, **kwargs): self.devmap[key].conn.send(("status", None, None)) return self.devmap[key].conn.recv() - @in_devmap - def task_status(self, address: str, channel: int): - started = self.dev_attr_get(attr="started", address=address, channel=channel) - if not started: - return Reply(success=True, msg="ready") - else: - return Reply(success=True, msg="running") - @in_devmap def task_start(self, address: str, channel: int, task: str, **kwargs): - if task not in self.tasks: + if task not in self.tasks(): return Reply( success=False, msg=f"unknown task {task!r} requested", - data=self.tasks, + data=self.tasks(), ) - reqs = self.tasks[task] + reqs = self.tasks()[task] for k, v in reqs.items(): if k not in kwargs and "default" not in v: logger.critical("Somehow we're here") @@ -123,14 +114,22 @@ def task_start(self, address: str, channel: int, task: str, **kwargs): msg=f"parameter {k!r} is wrong type", data=reqs, ) - self.dev_attr_set(attr=k, val=val, address=address, channel=channel) - self.dev_attr_set(attr="started", val=True, address=address, channel=channel) + self.dev_set_attr(attr=k, val=val, address=address, channel=channel) + self.dev_set_attr(attr="started", val=True, address=address, channel=channel) return Reply( success=True, msg=f"task {task!r} started successfully", data=kwargs, ) + @in_devmap + def task_status(self, address: str, channel: int): + started = self.dev_get_attr(attr="started", address=address, channel=channel) + if not started: + return Reply(success=True, msg="ready") + else: + return Reply(success=True, msg="running") + @in_devmap def task_data(self, address: str, channel: int, **kwargs): key = (address, channel) @@ -171,3 +170,7 @@ def teardown(self): logger.error(f"device {key!r} is still alive") else: logger.debug(f"device {key!r} successfully closed") + + +if __name__ == "__main__": + test = DriverInterface() diff --git a/src/tomato/models.py b/src/tomato/models.py index e27d99cb..5852a2ec 100644 --- a/src/tomato/models.py +++ b/src/tomato/models.py @@ -1,6 +1,8 @@ from pydantic import BaseModel, Field from typing import Optional, Any, Mapping, Sequence, Literal from pathlib import Path +from abc import ABCMeta, abstractmethod +import xarray as xr class Driver(BaseModel): @@ -68,3 +70,129 @@ class Reply(BaseModel): success: bool msg: str data: Optional[Any] = None + + +class DriverInterface(metaclass=ABCMeta): + class DeviceInterface(metaclass=ABCMeta): + """Class used to implement management of each individual device.""" + + pass + + devmap: dict[tuple, DeviceInterface] + """Map of registered devices, the tuple keys are components = (address, channel)""" + + settings: dict[str, str] + """A settings map to contain driver-specific settings such as `dllpath` for BioLogic""" + + def __init__(self, settings=None): + self.devmap = {} + self.settings = settings if settings is not None else {} + + def dev_register(self, address: str, channel: int, **kwargs: dict) -> None: + """ + Register a Device and its Component in this DriverInterface, creating a + :obj:`self.DeviceInterface` object in the :obj:`self.devmap` if necessary, or + updating existing channels in :obj:`self.devmap`. + """ + self.devmap[(address, channel)] = self.DeviceInterface(**kwargs) + + def dev_teardown(self, address: str, channel: int, **kwargs: dict) -> None: + """ + Emergency stop function. Set the device into a documented, safe state. + + The function is to be only called in case of critical errors, not as part of + normal operation. + """ + pass + + @abstractmethod + def attrs(self, address: str, channel: int, **kwargs) -> dict: + """ + Function that returns all gettable and settable attributes, their rw status, + and whether they are to be printed in `dev_status`. + + This is the "low level" control interface, intended for the device dashboard. + + Example: + -------- + return dict( + delay = dict(type=float, rw=True, status=False), + time = dict(type=float, rw=True, status=False), + started = dict(type=bool, rw=True, status=True), + val = dict(type=int, rw=False, status=True), + ) + """ + pass + + @abstractmethod + def dev_set_attr(self, attr: str, val: Any, address: str, channel: int, **kwargs): + """Set the value of a read-write attr on a Component""" + pass + + @abstractmethod + def dev_get_attr(self, attr: str, address: str, channel: int, **kwargs): + """Get the value of any attr from a Component""" + pass + + def dev_status(self, address: str, channel: int, **kwargs) -> dict[str, Any]: + """Get a status report from a Component""" + ret = {} + for k, v in self.attrs(address=address, channel=channel, **kwargs).items(): + if v.status: + ret[k] = self.dev_get_attr( + attr=k, address=address, channel=channel, **kwargs + ) + return ret + + # @abstractmethod + # def dev_get_data(self, address: str, channel: int, **kwargs): + # """Get a data report from a Component""" + # pass + + @abstractmethod + def tasks(self, address: str, channel: int, **kwargs) -> dict: + """ + Function that returns all tasks that can be submitted to the Device. This + implements the driver specific language. Each task in tasks can only contain + elements present in :func:`self.attrs`. + + Example: + return dict( + count = dict(time = dict(type=float), delay = dict(type=float), + ) + """ + pass + + @abstractmethod + def task_start(self, address: str, channel: int, task: str, **kwargs) -> None: + """start a task on a (ready) component""" + pass + + @abstractmethod + def task_status(self, address: str, channel: int) -> Literal["running", "ready"]: + """check task status of the component""" + pass + + @abstractmethod + def task_data(self, address: str, channel: int, **kwargs) -> xr.Dataset: + """get any cached data for the current task on the component""" + pass + + # @abstractmethod + # def task_stop(self, address: str, channel: int) -> xr.Dataset: + # """stops the current task, making the component ready and returning any data""" + # pass + + @abstractmethod + def status(self) -> dict: + """return status info of the driver""" + pass + + @abstractmethod + def teardown(self) -> None: + """ + Stop all tasks, tear down all devices, close all processes. + + Users can assume the devices are put in a safe state (valves closed, power off). + """ + pass From 42b77300e267fa58e155c87abc87b38bf4587a45 Mon Sep 17 00:00:00 2001 From: Peter Kraus Date: Fri, 31 May 2024 14:44:56 +0200 Subject: [PATCH 02/17] Two more functions. --- .../drivers/example_counter/__init__.py | 18 +++++-- src/tomato/models.py | 48 +++++++++++-------- 2 files changed, 43 insertions(+), 23 deletions(-) diff --git a/src/tomato/drivers/example_counter/__init__.py b/src/tomato/drivers/example_counter/__init__.py index 6f4d0d94..43f7ce0a 100644 --- a/src/tomato/drivers/example_counter/__init__.py +++ b/src/tomato/drivers/example_counter/__init__.py @@ -40,10 +40,10 @@ def __init__(self): def attrs(self, **kwargs) -> dict: return dict( - delay=dict(type=float, rw=True), - time=dict(type=float, rw=True), - started=dict(type=bool, rw=True), - val=dict(type=int, rw=False), + delay=dict(type=float, rw=True, status=False, data=True), + time=dict(type=float, rw=True, status=False, data=True), + started=dict(type=bool, rw=True, status=True, data=False), + val=dict(type=int, rw=False, status=True, data=True), ) def tasks(self, **kwargs) -> dict: @@ -130,6 +130,16 @@ def task_status(self, address: str, channel: int): else: return Reply(success=True, msg="running") + @in_devmap + def task_stop(self, address: str, channel: int): + self.dev_set_attr(attr="started", val=False, address=address, channel=channel) + + ret = self.task_data(self, address, channel) + if ret.success: + return Reply(success=True, msg=f"task stopped, {ret.msg}", data=ret.data) + else: + return Reply(success=True, msg=f"task stopped, {ret.msg}") + @in_devmap def task_data(self, address: str, channel: int, **kwargs): key = (address, channel) diff --git a/src/tomato/models.py b/src/tomato/models.py index 5852a2ec..f120af70 100644 --- a/src/tomato/models.py +++ b/src/tomato/models.py @@ -109,18 +109,21 @@ def dev_teardown(self, address: str, channel: int, **kwargs: dict) -> None: def attrs(self, address: str, channel: int, **kwargs) -> dict: """ Function that returns all gettable and settable attributes, their rw status, - and whether they are to be printed in `dev_status`. + and whether they are to be printed in :func:`self.dev_get_data` and + :func:`self.dev_status`. This is the "low level" control interface, intended for the device dashboard. Example: - -------- - return dict( - delay = dict(type=float, rw=True, status=False), - time = dict(type=float, rw=True, status=False), - started = dict(type=bool, rw=True, status=True), - val = dict(type=int, rw=False, status=True), - ) + :: + + return dict( + delay = dict(type=float, rw=True, status=False, data=True), + time = dict(type=float, rw=True, status=False, data=True), + started = dict(type=bool, rw=True, status=True, data=False), + val = dict(type=int, rw=False, status=True, data=True), + ) + """ pass @@ -144,10 +147,14 @@ def dev_status(self, address: str, channel: int, **kwargs) -> dict[str, Any]: ) return ret - # @abstractmethod - # def dev_get_data(self, address: str, channel: int, **kwargs): - # """Get a data report from a Component""" - # pass + def dev_get_data(self, address: str, channel: int, **kwargs): + ret = {} + for k, v in self.attrs(address=address, channel=channel, **kwargs).items(): + if v.data: + ret[k] = self.dev_get_attr( + attr=k, address=address, channel=channel, **kwargs + ) + return ret @abstractmethod def tasks(self, address: str, channel: int, **kwargs) -> dict: @@ -157,9 +164,12 @@ def tasks(self, address: str, channel: int, **kwargs) -> dict: elements present in :func:`self.attrs`. Example: - return dict( - count = dict(time = dict(type=float), delay = dict(type=float), - ) + :: + + return dict( + count = dict(time = dict(type=float), delay = dict(type=float), + ) + """ pass @@ -178,10 +188,10 @@ def task_data(self, address: str, channel: int, **kwargs) -> xr.Dataset: """get any cached data for the current task on the component""" pass - # @abstractmethod - # def task_stop(self, address: str, channel: int) -> xr.Dataset: - # """stops the current task, making the component ready and returning any data""" - # pass + @abstractmethod + def task_stop(self, address: str, channel: int) -> xr.Dataset: + """stops the current task, making the component ready and returning any data""" + pass @abstractmethod def status(self) -> dict: From 6b340c338087f22c8ec443585f9c167e7cadf58b Mon Sep 17 00:00:00 2001 From: Peter Kraus Date: Fri, 31 May 2024 16:45:50 +0200 Subject: [PATCH 03/17] ModelInterface and Attr --- .../drivers/example_counter/__init__.py | 17 +++++----- src/tomato/models.py | 32 +++++++++++-------- 2 files changed, 28 insertions(+), 21 deletions(-) diff --git a/src/tomato/drivers/example_counter/__init__.py b/src/tomato/drivers/example_counter/__init__.py index 43f7ce0a..d2b34b10 100644 --- a/src/tomato/drivers/example_counter/__init__.py +++ b/src/tomato/drivers/example_counter/__init__.py @@ -5,7 +5,7 @@ from functools import wraps from tomato.drivers.example_counter.counter import Counter -from tomato.models import Reply, DriverInterface +from tomato.models import Reply, ModelInterface from xarray import Dataset logger = logging.getLogger(__name__) @@ -24,7 +24,7 @@ def wrapper(self, **kwargs): return wrapper -class DriverInterface(DriverInterface): +class DriverInterface(ModelInterface): class DeviceInterface: dev: Counter conn: Connection @@ -40,10 +40,10 @@ def __init__(self): def attrs(self, **kwargs) -> dict: return dict( - delay=dict(type=float, rw=True, status=False, data=True), - time=dict(type=float, rw=True, status=False, data=True), - started=dict(type=bool, rw=True, status=True, data=False), - val=dict(type=int, rw=False, status=True, data=True), + delay=self.Attr(type=float, rw=True), + time=self.Attr(type=float, rw=True), + started=self.Attr(type=bool, rw=True, status=True), + val=self.Attr(type=int, status=True), ) def tasks(self, **kwargs) -> dict: @@ -70,7 +70,7 @@ def dev_set_attr(self, attr: str, val: Any, address: str, channel: int, **kwargs key = (address, channel) if attr in self.attrs(): params = self.attrs()[attr] - if params["rw"] and isinstance(val, params["type"]): + if params.rw and isinstance(val, params.type): self.devmap[key].conn.send(("set", attr, val)) @in_devmap @@ -183,4 +183,5 @@ def teardown(self): if __name__ == "__main__": - test = DriverInterface() + interface = DriverInterface() + print(f"{interface=}") diff --git a/src/tomato/models.py b/src/tomato/models.py index f120af70..1414f132 100644 --- a/src/tomato/models.py +++ b/src/tomato/models.py @@ -72,7 +72,14 @@ class Reply(BaseModel): data: Optional[Any] = None -class DriverInterface(metaclass=ABCMeta): +class ModelInterface(metaclass=ABCMeta): + class Attr(BaseModel): + """Class used to describe device attributes.""" + + type: type + rw: bool = False + status: bool = False + class DeviceInterface(metaclass=ABCMeta): """Class used to implement management of each individual device.""" @@ -106,11 +113,11 @@ def dev_teardown(self, address: str, channel: int, **kwargs: dict) -> None: pass @abstractmethod - def attrs(self, address: str, channel: int, **kwargs) -> dict: + def attrs(self, address: str, channel: int, **kwargs) -> dict[str, Attr]: """ Function that returns all gettable and settable attributes, their rw status, - and whether they are to be printed in :func:`self.dev_get_data` and - :func:`self.dev_status`. + and whether they are to be returned in :func:`self.dev_status`. All attrs are + returned by :func:`self.dev_get_data`. This is the "low level" control interface, intended for the device dashboard. @@ -118,10 +125,10 @@ def attrs(self, address: str, channel: int, **kwargs) -> dict: :: return dict( - delay = dict(type=float, rw=True, status=False, data=True), - time = dict(type=float, rw=True, status=False, data=True), - started = dict(type=bool, rw=True, status=True, data=False), - val = dict(type=int, rw=False, status=True, data=True), + delay = self.Attr(type=float, rw=True, status=False), + time = self.Attr(type=float, rw=True, status=False), + started = self.Attr(type=bool, rw=True, status=True), + val = self.Attr(type=int, rw=False, status=True), ) """ @@ -149,11 +156,10 @@ def dev_status(self, address: str, channel: int, **kwargs) -> dict[str, Any]: def dev_get_data(self, address: str, channel: int, **kwargs): ret = {} - for k, v in self.attrs(address=address, channel=channel, **kwargs).items(): - if v.data: - ret[k] = self.dev_get_attr( - attr=k, address=address, channel=channel, **kwargs - ) + for k in self.attrs(address=address, channel=channel, **kwargs).keys(): + ret[k] = self.dev_get_attr( + attr=k, address=address, channel=channel, **kwargs + ) return ret @abstractmethod From 532e0e0d9424676193f011a3320756e3bf09c9ef Mon Sep 17 00:00:00 2001 From: Peter Kraus Date: Fri, 31 May 2024 17:08:03 +0200 Subject: [PATCH 04/17] Fork out example_counter. --- pyproject.toml | 1 + src/tomato/daemon/driver.py | 9 +- src/tomato/drivers/__init__.py | 23 ++- .../drivers/example_counter/__init__.py | 187 ------------------ src/tomato/drivers/example_counter/counter.py | 62 ------ 5 files changed, 24 insertions(+), 258 deletions(-) delete mode 100644 src/tomato/drivers/example_counter/__init__.py delete mode 100644 src/tomato/drivers/example_counter/counter.py diff --git a/pyproject.toml b/pyproject.toml index 95678239..d0c9d652 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ dependencies = [ "h5netcdf >= 1.3", "xarray >= 2024.2", "pydantic ~= 1.0", + "tomato-example-counter @ git+https://github.com/dgbowl/tomato-example-counter.git", ] [project.optional-dependencies] diff --git a/src/tomato/daemon/driver.py b/src/tomato/daemon/driver.py index 84b9fc9c..b6165d1f 100644 --- a/src/tomato/daemon/driver.py +++ b/src/tomato/daemon/driver.py @@ -18,7 +18,7 @@ import zmq import psutil -import tomato.drivers +from tomato.drivers import driver_to_interface from tomato.models import Reply logger = logging.getLogger(__name__) @@ -96,12 +96,11 @@ def tomato_driver() -> None: logger.debug(f"{daemon=}") logger.info(f"attempting to spawn driver {args.driver!r}") - if not hasattr(tomato.drivers, args.driver): + Interface = driver_to_interface(args.driver) + if Interface is None: logger.critical(f"library of driver {args.driver!r} not found") return - - kwargs = dict(settings=daemon.drvs[args.driver].settings) - interface = getattr(tomato.drivers, args.driver).DriverInterface(**kwargs) + interface = Interface(settings=daemon.drvs[args.driver].settings) logger.info(f"registering devices in driver {args.driver!r}") for dev in daemon.devs.values(): diff --git a/src/tomato/drivers/__init__.py b/src/tomato/drivers/__init__.py index 0dee6233..10d0f3a4 100644 --- a/src/tomato/drivers/__init__.py +++ b/src/tomato/drivers/__init__.py @@ -1,9 +1,24 @@ """ Driver documentation goes here. """ +import importlib +import logging -from tomato.drivers import example_counter +from typing import Union +from tomato.models import ModelInterface -__all__ = [ - "example_counter", -] +logger = logging.getLogger(__name__) + +def driver_to_interface(drivername: str) -> Union[None, ModelInterface]: + modname = f"tomato_{drivername.replace('-', '_')}" + + try: + mod = importlib.import_module(modname) + except ModuleNotFoundError as e: + logger.critical("Error when loading 'DriverInteface': %s", e) + return None + else: + if hasattr(mod, "DriverInterface"): + return getattr(mod, "DriverInterface") + else: + return None \ No newline at end of file diff --git a/src/tomato/drivers/example_counter/__init__.py b/src/tomato/drivers/example_counter/__init__.py deleted file mode 100644 index d2b34b10..00000000 --- a/src/tomato/drivers/example_counter/__init__.py +++ /dev/null @@ -1,187 +0,0 @@ -import logging -from multiprocessing import Process, Pipe -from multiprocessing.connection import Connection -from typing import Any -from functools import wraps - -from tomato.drivers.example_counter.counter import Counter -from tomato.models import Reply, ModelInterface -from xarray import Dataset - -logger = logging.getLogger(__name__) - - -def in_devmap(func): - @wraps(func) - def wrapper(self, **kwargs): - address = kwargs.get("address") - channel = kwargs.get("channel") - if (address, channel) not in self.devmap: - msg = f"dev with address {address!r} and channel {channel} is unknown" - return Reply(success=False, msg=msg, data=self.devmap.keys()) - return func(self, **kwargs) - - return wrapper - - -class DriverInterface(ModelInterface): - class DeviceInterface: - dev: Counter - conn: Connection - proc: Process - - def __init__(self): - self.dev = Counter() - self.conn, conn = Pipe() - self.proc = Process(target=self.dev.run_counter, args=(conn,)) - - devmap: dict[tuple, DeviceInterface] - settings: dict - - def attrs(self, **kwargs) -> dict: - return dict( - delay=self.Attr(type=float, rw=True), - time=self.Attr(type=float, rw=True), - started=self.Attr(type=bool, rw=True, status=True), - val=self.Attr(type=int, status=True), - ) - - def tasks(self, **kwargs) -> dict: - return dict( - count=dict( - time=dict(type=float), - delay=dict(type=float), - ), - random=dict( - time=dict(type=float), - delay=dict(type=float), - min=dict(type=float), - max=dict(type=float), - ), - ) - - def dev_register(self, address: str, channel: int, **kwargs): - key = (address, channel) - self.devmap[key] = self.DeviceInterface() - self.devmap[key].proc.start() - - @in_devmap - def dev_set_attr(self, attr: str, val: Any, address: str, channel: int, **kwargs): - key = (address, channel) - if attr in self.attrs(): - params = self.attrs()[attr] - if params.rw and isinstance(val, params.type): - self.devmap[key].conn.send(("set", attr, val)) - - @in_devmap - def dev_get_attr(self, attr: str, address: str, channel: int, **kwargs): - key = (address, channel) - if attr in self.attrs(): - self.devmap[key].conn.send(("get", attr, None)) - return self.devmap[key].conn.recv() - - @in_devmap - def dev_status(self, address: str, channel: int, **kwargs): - key = (address, channel) - self.devmap[key].conn.send(("status", None, None)) - return self.devmap[key].conn.recv() - - @in_devmap - def task_start(self, address: str, channel: int, task: str, **kwargs): - if task not in self.tasks(): - return Reply( - success=False, - msg=f"unknown task {task!r} requested", - data=self.tasks(), - ) - - reqs = self.tasks()[task] - for k, v in reqs.items(): - if k not in kwargs and "default" not in v: - logger.critical("Somehow we're here") - logger.critical(f"{k=} {kwargs=}") - logger.critical(f"{v=}") - return Reply( - success=False, - msg=f"required parameter {k!r} missing", - data=reqs, - ) - val = kwargs.get(k, v.get("default")) - logger.critical(f"{k=} {val=}") - if not isinstance(val, v["type"]): - return Reply( - success=False, - msg=f"parameter {k!r} is wrong type", - data=reqs, - ) - self.dev_set_attr(attr=k, val=val, address=address, channel=channel) - self.dev_set_attr(attr="started", val=True, address=address, channel=channel) - return Reply( - success=True, - msg=f"task {task!r} started successfully", - data=kwargs, - ) - - @in_devmap - def task_status(self, address: str, channel: int): - started = self.dev_get_attr(attr="started", address=address, channel=channel) - if not started: - return Reply(success=True, msg="ready") - else: - return Reply(success=True, msg="running") - - @in_devmap - def task_stop(self, address: str, channel: int): - self.dev_set_attr(attr="started", val=False, address=address, channel=channel) - - ret = self.task_data(self, address, channel) - if ret.success: - return Reply(success=True, msg=f"task stopped, {ret.msg}", data=ret.data) - else: - return Reply(success=True, msg=f"task stopped, {ret.msg}") - - @in_devmap - def task_data(self, address: str, channel: int, **kwargs): - key = (address, channel) - self.devmap[key].conn.send(("data", None, None)) - data = self.devmap[key].conn.recv() - - if len(data) == 0: - return Reply(success=False, msg="found no new datapoints") - - data_vars = {} - for ii, item in enumerate(data): - for k, v in item.items(): - if k not in data_vars: - data_vars[k] = [None] * ii - data_vars[k].append(v) - for k in data_vars: - if k not in item: - data_vars[k].append(None) - - uts = {"uts": data_vars.pop("uts")} - data_vars = {k: ("uts", v) for k, v in data_vars.items()} - ds = Dataset(data_vars=data_vars, coords=uts) - return Reply(success=True, msg=f"found {len(data)} new datapoints", data=ds) - - def status(self): - devkeys = self.devmap.keys() - return Reply( - success=True, - msg=f"driver running with {len(devkeys)} devices", - data=dict(devkeys=devkeys), - ) - - def teardown(self): - for key, dev in self.devmap.items(): - dev.conn.send(("stop", None, None)) - dev.proc.join(1) - if dev.proc.is_alive(): - logger.error(f"device {key!r} is still alive") - else: - logger.debug(f"device {key!r} successfully closed") - - -if __name__ == "__main__": - interface = DriverInterface() - print(f"{interface=}") diff --git a/src/tomato/drivers/example_counter/counter.py b/src/tomato/drivers/example_counter/counter.py deleted file mode 100644 index 834361f1..00000000 --- a/src/tomato/drivers/example_counter/counter.py +++ /dev/null @@ -1,62 +0,0 @@ -import time -import math - - -class Counter: - delay: float - val: int - started: bool - started_at: float - time: float - end: bool - - def __init__(self, delay: float = 0.5): - self.val = 0 - self.delay = delay - self.started = False - self.started_at = None - self.time = None - self.end = False - - def run_counter(self, conn): - t0 = time.perf_counter() - data = [] - while True: - tN = time.perf_counter() - - if self.started: - if self.started_at is None: - self.started_at = tN - t0 = tN - self.val = math.floor(tN - self.started_at) - if tN - t0 > self.delay: - data.append(dict(uts=tN, val=self.val)) - t0 += self.delay - if self.time is not None and tN - self.started_at > self.time: - self.started = False - self.started_at = None - - cmd = None - if conn.poll(1e-3): - cmd, attr, val = conn.recv() - - if cmd == "set": - if attr == "delay": - self.delay = val - elif attr == "time": - self.time = val - elif attr == "started": - self.started = val - data = [] - elif cmd == "get": - if hasattr(self, attr): - conn.send(getattr(self, attr)) - elif cmd == "stop": - break - elif cmd == "status": - conn.send(self.val) - elif cmd == "data": - conn.send(data) - data = [] - else: - time.sleep(1e-3) From b37491439ffdd2449cd8e2849ed17fb220340418 Mon Sep 17 00:00:00 2001 From: Peter Kraus Date: Fri, 31 May 2024 17:17:07 +0200 Subject: [PATCH 05/17] ruff --- src/tomato/drivers/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/tomato/drivers/__init__.py b/src/tomato/drivers/__init__.py index 10d0f3a4..8e740868 100644 --- a/src/tomato/drivers/__init__.py +++ b/src/tomato/drivers/__init__.py @@ -1,6 +1,7 @@ """ Driver documentation goes here. """ + import importlib import logging @@ -9,6 +10,7 @@ logger = logging.getLogger(__name__) + def driver_to_interface(drivername: str) -> Union[None, ModelInterface]: modname = f"tomato_{drivername.replace('-', '_')}" @@ -21,4 +23,4 @@ def driver_to_interface(drivername: str) -> Union[None, ModelInterface]: if hasattr(mod, "DriverInterface"): return getattr(mod, "DriverInterface") else: - return None \ No newline at end of file + return None From 01e2e131c45e8b7b3633a7cd5b606a6ecbbcc08e Mon Sep 17 00:00:00 2001 From: Peter Kraus Date: Fri, 31 May 2024 20:58:57 +0200 Subject: [PATCH 06/17] Fix pages & ruff. --- docs/source/index.rst | 1 - src/tomato/models.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/docs/source/index.rst b/docs/source/index.rst index f048e51b..17ecceb1 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -42,7 +42,6 @@ sustainable batteries of the future. :maxdepth: 1 :caption: tomato driver library - apidoc/tomato.drivers.example_counter .. toctree:: :maxdepth: 1 diff --git a/src/tomato/models.py b/src/tomato/models.py index 1414f132..e300ff17 100644 --- a/src/tomato/models.py +++ b/src/tomato/models.py @@ -1,5 +1,5 @@ from pydantic import BaseModel, Field -from typing import Optional, Any, Mapping, Sequence, Literal +from typing import Optional, Any, Mapping, Sequence, Literal, TypeVar from pathlib import Path from abc import ABCMeta, abstractmethod import xarray as xr @@ -76,7 +76,7 @@ class ModelInterface(metaclass=ABCMeta): class Attr(BaseModel): """Class used to describe device attributes.""" - type: type + type: TypeVar("T") rw: bool = False status: bool = False From c7e1b604e7375f9a8d330229298bf7e7ad5bf46b Mon Sep 17 00:00:00 2001 From: Peter Kraus Date: Sat, 1 Jun 2024 09:27:35 +0200 Subject: [PATCH 07/17] Refactor into versioned driverinterface --- src/tomato/driverinterface_1_0/__init__.py | 148 ++++++++++++++++++ .../{drivers/__init__.py => drivers.py} | 0 src/tomato/models.py | 147 +---------------- 3 files changed, 150 insertions(+), 145 deletions(-) create mode 100644 src/tomato/driverinterface_1_0/__init__.py rename src/tomato/{drivers/__init__.py => drivers.py} (100%) diff --git a/src/tomato/driverinterface_1_0/__init__.py b/src/tomato/driverinterface_1_0/__init__.py new file mode 100644 index 00000000..aa3fa87f --- /dev/null +++ b/src/tomato/driverinterface_1_0/__init__.py @@ -0,0 +1,148 @@ +from abc import ABCMeta, abstractmethod +import xarray as xr +from typing import TypeVar, Any, Literal +from pydantic import BaseModel + + +class ModelInterface(metaclass=ABCMeta): + version: Literal = "1.0" + + class Attr(BaseModel): + """Class used to describe device attributes.""" + + type: TypeVar("T") + rw: bool = False + status: bool = False + + class DeviceInterface(metaclass=ABCMeta): + """Class used to implement management of each individual device.""" + + pass + + devmap: dict[tuple, DeviceInterface] + """Map of registered devices, the tuple keys are components = (address, channel)""" + + settings: dict[str, str] + """A settings map to contain driver-specific settings such as `dllpath` for BioLogic""" + + def __init__(self, settings=None): + self.devmap = {} + self.settings = settings if settings is not None else {} + + def dev_register(self, address: str, channel: int, **kwargs: dict) -> None: + """ + Register a Device and its Component in this DriverInterface, creating a + :obj:`self.DeviceInterface` object in the :obj:`self.devmap` if necessary, or + updating existing channels in :obj:`self.devmap`. + """ + self.devmap[(address, channel)] = self.DeviceInterface(**kwargs) + + def dev_teardown(self, address: str, channel: int, **kwargs: dict) -> None: + """ + Emergency stop function. Set the device into a documented, safe state. + + The function is to be only called in case of critical errors, not as part of + normal operation. + """ + pass + + @abstractmethod + def attrs(self, address: str, channel: int, **kwargs) -> dict[str, Attr]: + """ + Function that returns all gettable and settable attributes, their rw status, + and whether they are to be returned in :func:`self.dev_status`. All attrs are + returned by :func:`self.dev_get_data`. + + This is the "low level" control interface, intended for the device dashboard. + + Example: + :: + + return dict( + delay = self.Attr(type=float, rw=True, status=False), + time = self.Attr(type=float, rw=True, status=False), + started = self.Attr(type=bool, rw=True, status=True), + val = self.Attr(type=int, rw=False, status=True), + ) + + """ + pass + + @abstractmethod + def dev_set_attr(self, attr: str, val: Any, address: str, channel: int, **kwargs): + """Set the value of a read-write attr on a Component""" + pass + + @abstractmethod + def dev_get_attr(self, attr: str, address: str, channel: int, **kwargs): + """Get the value of any attr from a Component""" + pass + + def dev_status(self, address: str, channel: int, **kwargs) -> dict[str, Any]: + """Get a status report from a Component""" + ret = {} + for k, v in self.attrs(address=address, channel=channel, **kwargs).items(): + if v.status: + ret[k] = self.dev_get_attr( + attr=k, address=address, channel=channel, **kwargs + ) + return ret + + def dev_get_data(self, address: str, channel: int, **kwargs): + ret = {} + for k in self.attrs(address=address, channel=channel, **kwargs).keys(): + ret[k] = self.dev_get_attr( + attr=k, address=address, channel=channel, **kwargs + ) + return ret + + @abstractmethod + def tasks(self, address: str, channel: int, **kwargs) -> dict: + """ + Function that returns all tasks that can be submitted to the Device. This + implements the driver specific language. Each task in tasks can only contain + elements present in :func:`self.attrs`. + + Example: + :: + + return dict( + count = dict(time = dict(type=float), delay = dict(type=float), + ) + + """ + pass + + @abstractmethod + def task_start(self, address: str, channel: int, task: str, **kwargs) -> None: + """start a task on a (ready) component""" + pass + + @abstractmethod + def task_status(self, address: str, channel: int) -> Literal["running", "ready"]: + """check task status of the component""" + pass + + @abstractmethod + def task_data(self, address: str, channel: int, **kwargs) -> xr.Dataset: + """get any cached data for the current task on the component""" + pass + + @abstractmethod + def task_stop(self, address: str, channel: int) -> xr.Dataset: + """stops the current task, making the component ready and returning any data""" + pass + + @abstractmethod + def status(self) -> dict: + """return status info of the driver""" + pass + + @abstractmethod + def teardown(self) -> None: + """ + Stop all tasks, tear down all devices, close all processes. + + Users can assume the devices are put in a safe state (valves closed, power off). + """ + pass diff --git a/src/tomato/drivers/__init__.py b/src/tomato/drivers.py similarity index 100% rename from src/tomato/drivers/__init__.py rename to src/tomato/drivers.py diff --git a/src/tomato/models.py b/src/tomato/models.py index e300ff17..6c50406d 100644 --- a/src/tomato/models.py +++ b/src/tomato/models.py @@ -1,8 +1,7 @@ from pydantic import BaseModel, Field -from typing import Optional, Any, Mapping, Sequence, Literal, TypeVar +from typing import Optional, Any, Mapping, Sequence, Literal from pathlib import Path -from abc import ABCMeta, abstractmethod -import xarray as xr +from tomato.driverinterface_1_0 import ModelInterface as ModelInterface class Driver(BaseModel): @@ -70,145 +69,3 @@ class Reply(BaseModel): success: bool msg: str data: Optional[Any] = None - - -class ModelInterface(metaclass=ABCMeta): - class Attr(BaseModel): - """Class used to describe device attributes.""" - - type: TypeVar("T") - rw: bool = False - status: bool = False - - class DeviceInterface(metaclass=ABCMeta): - """Class used to implement management of each individual device.""" - - pass - - devmap: dict[tuple, DeviceInterface] - """Map of registered devices, the tuple keys are components = (address, channel)""" - - settings: dict[str, str] - """A settings map to contain driver-specific settings such as `dllpath` for BioLogic""" - - def __init__(self, settings=None): - self.devmap = {} - self.settings = settings if settings is not None else {} - - def dev_register(self, address: str, channel: int, **kwargs: dict) -> None: - """ - Register a Device and its Component in this DriverInterface, creating a - :obj:`self.DeviceInterface` object in the :obj:`self.devmap` if necessary, or - updating existing channels in :obj:`self.devmap`. - """ - self.devmap[(address, channel)] = self.DeviceInterface(**kwargs) - - def dev_teardown(self, address: str, channel: int, **kwargs: dict) -> None: - """ - Emergency stop function. Set the device into a documented, safe state. - - The function is to be only called in case of critical errors, not as part of - normal operation. - """ - pass - - @abstractmethod - def attrs(self, address: str, channel: int, **kwargs) -> dict[str, Attr]: - """ - Function that returns all gettable and settable attributes, their rw status, - and whether they are to be returned in :func:`self.dev_status`. All attrs are - returned by :func:`self.dev_get_data`. - - This is the "low level" control interface, intended for the device dashboard. - - Example: - :: - - return dict( - delay = self.Attr(type=float, rw=True, status=False), - time = self.Attr(type=float, rw=True, status=False), - started = self.Attr(type=bool, rw=True, status=True), - val = self.Attr(type=int, rw=False, status=True), - ) - - """ - pass - - @abstractmethod - def dev_set_attr(self, attr: str, val: Any, address: str, channel: int, **kwargs): - """Set the value of a read-write attr on a Component""" - pass - - @abstractmethod - def dev_get_attr(self, attr: str, address: str, channel: int, **kwargs): - """Get the value of any attr from a Component""" - pass - - def dev_status(self, address: str, channel: int, **kwargs) -> dict[str, Any]: - """Get a status report from a Component""" - ret = {} - for k, v in self.attrs(address=address, channel=channel, **kwargs).items(): - if v.status: - ret[k] = self.dev_get_attr( - attr=k, address=address, channel=channel, **kwargs - ) - return ret - - def dev_get_data(self, address: str, channel: int, **kwargs): - ret = {} - for k in self.attrs(address=address, channel=channel, **kwargs).keys(): - ret[k] = self.dev_get_attr( - attr=k, address=address, channel=channel, **kwargs - ) - return ret - - @abstractmethod - def tasks(self, address: str, channel: int, **kwargs) -> dict: - """ - Function that returns all tasks that can be submitted to the Device. This - implements the driver specific language. Each task in tasks can only contain - elements present in :func:`self.attrs`. - - Example: - :: - - return dict( - count = dict(time = dict(type=float), delay = dict(type=float), - ) - - """ - pass - - @abstractmethod - def task_start(self, address: str, channel: int, task: str, **kwargs) -> None: - """start a task on a (ready) component""" - pass - - @abstractmethod - def task_status(self, address: str, channel: int) -> Literal["running", "ready"]: - """check task status of the component""" - pass - - @abstractmethod - def task_data(self, address: str, channel: int, **kwargs) -> xr.Dataset: - """get any cached data for the current task on the component""" - pass - - @abstractmethod - def task_stop(self, address: str, channel: int) -> xr.Dataset: - """stops the current task, making the component ready and returning any data""" - pass - - @abstractmethod - def status(self) -> dict: - """return status info of the driver""" - pass - - @abstractmethod - def teardown(self) -> None: - """ - Stop all tasks, tear down all devices, close all processes. - - Users can assume the devices are put in a safe state (valves closed, power off). - """ - pass From 6f23f82b3a45eeed6275df8f70186f2d1422dde6 Mon Sep 17 00:00:00 2001 From: Peter Kraus Date: Sun, 7 Jul 2024 20:02:38 +0200 Subject: [PATCH 08/17] Tests pass. --- pyproject.toml | 12 +- src/tomato/daemon/driver.py | 7 +- src/tomato/daemon/job.py | 47 ++--- src/tomato/driverinterface_1_0/__init__.py | 215 +++++++++++++++------ src/tomato/drivers.py | 2 +- src/tomato/ketchup/__init__.py | 25 ++- src/tomato/models.py | 1 - src/tomato/tomato/__init__.py | 10 +- 8 files changed, 223 insertions(+), 96 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d0c9d652..54c09684 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,11 +28,11 @@ dependencies = [ "toml >= 0.10", "pyyaml >= 6.0", "psutil >= 5.9", - "dgbowl_schemas >= 108", + "dgbowl_schemas @ git+https://github.com/dgbowl/dgbowl-schemas.git@Payload_1.0", "pyzmq >= 25.1", "h5netcdf >= 1.3", "xarray >= 2024.2", - "pydantic ~= 1.0", + "pydantic >= 2.0", "tomato-example-counter @ git+https://github.com/dgbowl/tomato-example-counter.git", ] @@ -41,7 +41,7 @@ testing = ["pytest"] docs = [ "sphinx ~= 7.2", "sphinx-rtd-theme ~= 1.3.0", - "autodoc-pydantic ~= 1.9.0", + "autodoc-pydantic ~= 2.1", "sphinxcontrib-mermaid ~= 0.9.2", ] @@ -63,4 +63,8 @@ enabled = true dev_template = "{tag}.dev{ccount}" dirty_template = "{tag}.dev{ccount}" -[tool.ruff] \ No newline at end of file +[tool.ruff] + +[tool.pytest.ini_options] +log_cli = false +log_cli_level = "DEBUG" \ No newline at end of file diff --git a/src/tomato/daemon/driver.py b/src/tomato/daemon/driver.py index b6165d1f..f03df824 100644 --- a/src/tomato/daemon/driver.py +++ b/src/tomato/daemon/driver.py @@ -56,7 +56,7 @@ def tomato_driver() -> None: parser.add_argument( "--verbosity", help="Verbosity of the tomato-driver.", - default=logging.INFO, + default=logging.DEBUG, type=int, ) parser.add_argument( @@ -174,7 +174,6 @@ def tomato_driver() -> None: ret = interface.task_start(**msg["params"]) elif msg["cmd"] == "task_data": ret = interface.task_data(**msg["params"]) - logger.debug(f"{ret=}") rep.send_pyobj(ret) if status == "stop": break @@ -187,7 +186,8 @@ def tomato_driver() -> None: def spawn_tomato_driver(port: int, driver: str, req: zmq.Socket, verbosity: int): - cmd = ["tomato-driver", "--port", str(port), "--verbosity", str(verbosity), driver] + # cmd = ["tomato-driver", "--port", str(port), "--verbosity", str(verbosity), driver] + cmd = ["tomato-driver", "--port", str(port), driver] if psutil.WINDOWS: cfs = subprocess.CREATE_NO_WINDOW cfs |= subprocess.CREATE_NEW_PROCESS_GROUP @@ -276,7 +276,6 @@ def manager(port: int, timeout: int = 1000): for driver in daemon.drvs.values(): logger.debug(f"stopping driver {driver.name!r} on port {driver.port}") ret = stop_tomato_driver(driver.port, context) - logger.debug(f"{ret=}") if ret.success: logger.info(f"stopped driver {driver.name!r}") else: diff --git a/src/tomato/daemon/job.py b/src/tomato/daemon/job.py index b458b588..2544a226 100644 --- a/src/tomato/daemon/job.py +++ b/src/tomato/daemon/job.py @@ -28,18 +28,20 @@ from tomato.daemon.io import merge_netcdfs, data_to_pickle from tomato.models import Pipeline, Daemon, Component, Device, Driver +from dgbowl_schemas.tomato import to_payload +from dgbowl_schemas.tomato.payload import Payload logger = logging.getLogger(__name__) def find_matching_pipelines(daemon: Daemon, method: list[dict]) -> list[str]: - req_names = set([item.device for item in method]) - req_capabs = set([item.technique for item in method]) + req_tags = set([item.component_tag for item in method]) + req_capabs = set([item.technique_name for item in method]) candidates = [] for pip in daemon.pips.values(): dnames = set([comp.role for comp in pip.devs.values()]) - if req_names.intersection(dnames) == req_names: + if req_tags.intersection(dnames) == req_tags: candidates.append(pip) matched = [] @@ -189,7 +191,7 @@ def action_queued_jobs(daemon, matched, req): jpath = root / "jobdata.json" jobargs = { "pipeline": pip.dict(), - "payload": job.payload.dict(), + "payload": job.payload.model_dump(), "devices": {dname: dev.dict() for dname, dev in daemon.devs.items()}, "job": dict(id=job.id, path=str(root)), } @@ -325,8 +327,8 @@ def tomato_job() -> None: with args.jobfile.open() as infile: jsdata = json.load(infile) - payload = jsdata["payload"] - ready = payload["tomato"].get("unlock_when_done", False) + payload = to_payload(**jsdata["payload"]) + pip = jsdata["pipeline"]["name"] jobid = jsdata["job"]["id"] jobpath = Path(jsdata["job"]["path"]).resolve() @@ -339,8 +341,10 @@ def tomato_job() -> None: ) logger = logging.getLogger(__name__) - tomato = payload.get("tomato", {}) - verbosity = tomato.get("verbosity", "INFO") + logger.debug(f"{payload=}") + + ready = payload.settings.unlock_when_done + verbosity = payload.settings.verbosity loglevel = logging._checkLevel(verbosity) logger.debug("setting logger verbosity to '%s'", verbosity) logger.setLevel(loglevel) @@ -364,15 +368,15 @@ def tomato_job() -> None: params = dict(pid=pid, status="r", executed_at=str(datetime.now(timezone.utc))) lazy_pirate(pyobj=dict(cmd="job", id=jobid, params=params), **pkwargs) - output = tomato["output"] - outpath = Path(output["path"]) + output = payload.settings.output + outpath = Path(output.path) logger.debug(f"output folder is {outpath}") if outpath.exists(): assert outpath.is_dir() else: logger.debug("path does not exist, creating") os.makedirs(outpath) - prefix = f"results.{jobid}" if output["prefix"] is None else output["prefix"] + prefix = f"results.{jobid}" if output.prefix is None else output.prefix respath = outpath / f"{prefix}.nc" snappath = outpath / f"snapshot.{jobid}.nc" params = dict(respath=str(respath), snappath=str(snappath), jobpath=str(jobpath)) @@ -446,7 +450,7 @@ def job_thread( if ret.success and ret.msg == "ready": break - req.send_pyobj(dict(cmd="task_start", params={**task, **kwargs})) + req.send_pyobj(dict(cmd="task_start", params={"task": task, **kwargs})) ret = req.recv_pyobj() logger.debug(f"{ret=}") @@ -475,7 +479,7 @@ def job_thread( def job_main_loop( context: zmq.Context, port: int, - payload: dict, + payload: Payload, pipname: str, jobpath: Path, snappath: Path, @@ -505,13 +509,10 @@ def job_main_loop( # collate steps by role plan = {} - for step in payload["method"]: - if step["device"] not in plan: - plan[step["device"]] = [] - task = {k: v for k, v in step.items()} - del task["device"] - task["task"] = task.pop("technique") - plan[step["device"]].append(task) + for step in payload.method: + if step.component_tag not in plan: + plan[step.component_tag] = [] + plan[step.component_tag].append(step) logger.debug(f"{plan=}") # distribute plan into threads @@ -531,15 +532,15 @@ def job_main_loop( threads[role].start() # wait until threads join or we're killed - snapshot = payload["tomato"].get("snapshot", None) + snapshot = payload.settings.snapshot t0 = time.perf_counter() while True: logger.debug("tick") tN = time.perf_counter() - if snapshot is not None and tN - t0 > snapshot["frequency"]: + if snapshot is not None and tN - t0 > snapshot.frequency: logger.debug("creating snapshot") merge_netcdfs(jobpath, snappath) - t0 += snapshot["frequency"] + t0 += snapshot.frequency joined = [proc.is_alive() is False for proc in threads.values()] if all(joined): break diff --git a/src/tomato/driverinterface_1_0/__init__.py b/src/tomato/driverinterface_1_0/__init__.py index aa3fa87f..53ee3869 100644 --- a/src/tomato/driverinterface_1_0/__init__.py +++ b/src/tomato/driverinterface_1_0/__init__.py @@ -1,7 +1,29 @@ from abc import ABCMeta, abstractmethod -import xarray as xr from typing import TypeVar, Any, Literal from pydantic import BaseModel +from threading import Thread, currentThread +from queue import Queue +from tomato.models import Reply +from dgbowl_schemas.tomato.payload import Task +import logging +from functools import wraps +from xarray import Dataset + + +logger = logging.getLogger(__name__) + + +def in_devmap(func): + @wraps(func) + def wrapper(self, **kwargs): + address = kwargs.get("address") + channel = kwargs.get("channel") + if (address, channel) not in self.devmap: + msg = f"dev with address {address!r} and channel {channel} is unknown" + return Reply(success=False, msg=msg, data=self.devmap.keys()) + return func(self, **kwargs) + + return wrapper class ModelInterface(metaclass=ABCMeta): @@ -15,9 +37,45 @@ class Attr(BaseModel): status: bool = False class DeviceInterface(metaclass=ABCMeta): - """Class used to implement management of each individual device.""" - - pass + driver: object + data: list + status: dict + key: tuple + thread: Thread + task_list: Queue + running: bool + + def __init__(self, driver, key, **kwargs): + self.driver = driver + self.key = key + self.task_list = Queue() + self.thread = Thread(target=self._worker_wrapper, daemon=True) + self.data = [] + self.status = {} + self.running = False + + def run(self): + self.thread.do_run = True + self.thread.start() + self.running = True + + def _worker_wrapper(self): + thread = currentThread() + task = self.task_list.get() + + self.task_runner(task, thread) + + self.task_list.task_done() + self.running = False + self.thread = Thread(target=self._worker_wrapper, daemon=True) + + @abstractmethod + def task_runner(task: Task, thread: Thread): + pass + + def CreateDeviceInterface(self, key, **kwargs): + """Factory function which passes DriverInterface to the DeviceInterface""" + return self.DeviceInterface(self, key, **kwargs) devmap: dict[tuple, DeviceInterface] """Map of registered devices, the tuple keys are components = (address, channel)""" @@ -35,7 +93,8 @@ def dev_register(self, address: str, channel: int, **kwargs: dict) -> None: :obj:`self.DeviceInterface` object in the :obj:`self.devmap` if necessary, or updating existing channels in :obj:`self.devmap`. """ - self.devmap[(address, channel)] = self.DeviceInterface(**kwargs) + key = (address, channel) + self.devmap[(address, channel)] = self.CreateDeviceInterface(key, **kwargs) def dev_teardown(self, address: str, channel: int, **kwargs: dict) -> None: """ @@ -68,25 +127,107 @@ def attrs(self, address: str, channel: int, **kwargs) -> dict[str, Attr]: """ pass - @abstractmethod + @in_devmap def dev_set_attr(self, attr: str, val: Any, address: str, channel: int, **kwargs): - """Set the value of a read-write attr on a Component""" - pass + key = (address, channel) + if attr in self.attrs(): + params = self.attrs()[attr] + if params.rw and isinstance(val, params.type): + self.devmap[key].status[attr] = val - @abstractmethod + @in_devmap def dev_get_attr(self, attr: str, address: str, channel: int, **kwargs): - """Get the value of any attr from a Component""" - pass + key = (address, channel) + if attr in self.attrs(address=address, channel=channel): + return self.devmap[key].status[attr] + + @in_devmap + def dev_status(self, address: str, channel: int, **kwargs): + key = (address, channel) + running = self.devmap[key].running + return Reply( + success=True, + msg=f"component {key} is{' ' if running else ' not ' }running", + data=running, + ) + + @in_devmap + def task_start(self, address: str, channel: int, task: Task, **kwargs): + if task.technique_name not in self.tasks(address=address, channel=channel): + return Reply( + success=False, + msg=f"unknown task {task.technique_name!r} requested", + data=self.tasks(), + ) - def dev_status(self, address: str, channel: int, **kwargs) -> dict[str, Any]: - """Get a status report from a Component""" - ret = {} - for k, v in self.attrs(address=address, channel=channel, **kwargs).items(): - if v.status: - ret[k] = self.dev_get_attr( - attr=k, address=address, channel=channel, **kwargs - ) - return ret + key = (address, channel) + self.devmap[key].task_list.put(task) + self.devmap[key].run() + return Reply( + success=True, + msg=f"task {task!r} started successfully", + data=task, + ) + + @in_devmap + def task_status(self, address: str, channel: int): + key = (address, channel) + started = self.devmap[key].running + if not started: + return Reply(success=True, msg="ready") + else: + return Reply(success=True, msg="running") + + @in_devmap + def task_stop(self, address: str, channel: int): + self.dev_set_attr(attr="started", val=False, address=address, channel=channel) + + ret = self.task_data(self, address, channel) + if ret.success: + return Reply(success=True, msg=f"task stopped, {ret.msg}", data=ret.data) + else: + return Reply(success=True, msg=f"task stopped, {ret.msg}") + + @in_devmap + def task_data(self, address: str, channel: int, **kwargs): + key = (address, channel) + data = self.devmap[key].data + self.devmap[key].data = [] + + if len(data) == 0: + return Reply(success=False, msg="found no new datapoints") + + data_vars = {} + for ii, item in enumerate(data): + for k, v in item.items(): + if k not in data_vars: + data_vars[k] = [None] * ii + data_vars[k].append(v) + for k in data_vars: + if k not in item: + data_vars[k].append(None) + + uts = {"uts": data_vars.pop("uts")} + data_vars = {k: ("uts", v) for k, v in data_vars.items()} + ds = Dataset(data_vars=data_vars, coords=uts) + return Reply(success=True, msg=f"found {len(data)} new datapoints", data=ds) + + def status(self): + devkeys = self.devmap.keys() + return Reply( + success=True, + msg=f"driver running with {len(devkeys)} devices", + data=dict(devkeys=devkeys), + ) + + def teardown(self): + for key, dev in self.devmap.items(): + dev.thread.do_run = False + dev.thread.join(1) + if dev.thread.is_alive(): + logger.error(f"device {key!r} is still alive") + else: + logger.debug(f"device {key!r} successfully closed") def dev_get_data(self, address: str, channel: int, **kwargs): ret = {} @@ -112,37 +253,3 @@ def tasks(self, address: str, channel: int, **kwargs) -> dict: """ pass - - @abstractmethod - def task_start(self, address: str, channel: int, task: str, **kwargs) -> None: - """start a task on a (ready) component""" - pass - - @abstractmethod - def task_status(self, address: str, channel: int) -> Literal["running", "ready"]: - """check task status of the component""" - pass - - @abstractmethod - def task_data(self, address: str, channel: int, **kwargs) -> xr.Dataset: - """get any cached data for the current task on the component""" - pass - - @abstractmethod - def task_stop(self, address: str, channel: int) -> xr.Dataset: - """stops the current task, making the component ready and returning any data""" - pass - - @abstractmethod - def status(self) -> dict: - """return status info of the driver""" - pass - - @abstractmethod - def teardown(self) -> None: - """ - Stop all tasks, tear down all devices, close all processes. - - Users can assume the devices are put in a safe state (valves closed, power off). - """ - pass diff --git a/src/tomato/drivers.py b/src/tomato/drivers.py index 8e740868..0d7d8401 100644 --- a/src/tomato/drivers.py +++ b/src/tomato/drivers.py @@ -6,7 +6,7 @@ import logging from typing import Union -from tomato.models import ModelInterface +from tomato.driverinterface_1_0 import ModelInterface logger = logging.getLogger(__name__) diff --git a/src/tomato/ketchup/__init__.py b/src/tomato/ketchup/__init__.py index 912bf4d4..1ab996e8 100644 --- a/src/tomato/ketchup/__init__.py +++ b/src/tomato/ketchup/__init__.py @@ -21,6 +21,7 @@ from datetime import datetime, timezone import yaml import zmq +from packaging.version import Version from dgbowl_schemas.tomato import to_payload from tomato.daemon.io import merge_netcdfs @@ -28,6 +29,8 @@ log = logging.getLogger(__name__) +__latest_payload__ = "1.0" + def submit( *, @@ -95,15 +98,23 @@ def submit( return Reply(success=False, msg="payload must be a yaml or a json file") payload = to_payload(**pldict) - if payload.tomato.output.path is None: + maxver = Version(__latest_payload__) + while hasattr(payload, "update"): + temp = payload.update() + if hasattr(temp, "version"): + if Version(temp.version) > maxver: + break + payload = temp + print(f"{payload=}") + + if payload.settings.output.path is None: cwd = str(Path().resolve()) log.info(f"Output path not set. Setting output path to {cwd}") - payload.tomato.output.path = cwd - if hasattr(payload.tomato, "snapshot"): - if payload.tomato.snapshot is not None and payload.tomato.snapshot.path is None: - cwd = str(Path().resolve()) - log.info(f"Snapshot path not set. Setting output path to {cwd}") - payload.tomato.snapshot.path = cwd + payload.settings.output.path = cwd + if payload.settings.snapshot is not None and payload.settings.snapshot.path is None: + cwd = str(Path().resolve()) + log.info(f"Snapshot path not set. Setting output path to {cwd}") + payload.settings.snapshot.path = cwd log.debug("queueing 'payload' into 'queue'") req = context.socket(zmq.REQ) diff --git a/src/tomato/models.py b/src/tomato/models.py index 6c50406d..e27d99cb 100644 --- a/src/tomato/models.py +++ b/src/tomato/models.py @@ -1,7 +1,6 @@ from pydantic import BaseModel, Field from typing import Optional, Any, Mapping, Sequence, Literal from pathlib import Path -from tomato.driverinterface_1_0 import ModelInterface as ModelInterface class Driver(BaseModel): diff --git a/src/tomato/tomato/__init__.py b/src/tomato/tomato/__init__.py index 5b654c03..90d214f3 100644 --- a/src/tomato/tomato/__init__.py +++ b/src/tomato/tomato/__init__.py @@ -327,8 +327,8 @@ def reload( if not stat.success: return stat daemon = stat.data - logger.critical(f"{daemon.status=}") - logger.critical(f"{daemon.pips=}") + logger.debug(f"{daemon.status=}") + logger.debug(f"{daemon.pips=}") req = context.socket(zmq.REQ) req.connect(f"tcp://127.0.0.1:{port}") if daemon.status == "bootstrap": @@ -343,6 +343,7 @@ def reload( ) ) rep = req.recv_pyobj() + logger.debug(rep) elif daemon.status == "running": retries = 0 while True: @@ -361,6 +362,7 @@ def reload( # check changes in driver settings for drv in drvs.values(): + logger.debug(f"{drv=}") if drv.settings != daemon.drvs[drv.name].settings: ret = _updater( context, daemon.drvs[drv.name].port, "settings", drv.settings @@ -375,6 +377,7 @@ def reload( # check changes in devices for dev in devs.values(): + logger.debug(f"{dev=}") if ( dev.name not in daemon.devs or dev.channels != daemon.devs[dev.name].channels @@ -385,9 +388,11 @@ def reload( channel=channel, capabilities=dev.capabilities, ) + logger.debug(f"{params=}") ret = _updater( context, daemon.drvs[drv.name].port, "dev_register", params ) + logger.debug(f"{ret=}") if ret.success is False: return ret params = dev.dict() @@ -401,6 +406,7 @@ def reload( logger.error("removing devices not yet implemented") # check changes in pipelines for pip in pips.values(): + logger.debug(f"{pip=}") if pip.name not in daemon.pips: ret = _updater(context, port, "pipeline", pip.dict()) if ret.success is False: From af20439813fa5ead490ee8ff72d2c9cffb68cf87 Mon Sep 17 00:00:00 2001 From: Peter Kraus Date: Wed, 10 Jul 2024 14:54:31 +0200 Subject: [PATCH 09/17] Refactor driver interface + warnings --- src/tomato/daemon/job.py | 4 +- src/tomato/driverinterface_1_0/__init__.py | 209 ++++++++++++--------- src/tomato/drivers.py | 4 +- src/tomato/tomato/__init__.py | 4 +- 4 files changed, 127 insertions(+), 94 deletions(-) diff --git a/src/tomato/daemon/job.py b/src/tomato/daemon/job.py index 2544a226..ee774ad9 100644 --- a/src/tomato/daemon/job.py +++ b/src/tomato/daemon/job.py @@ -190,9 +190,9 @@ def action_queued_jobs(daemon, matched, req): jpath = root / "jobdata.json" jobargs = { - "pipeline": pip.dict(), + "pipeline": pip.model_dump(), "payload": job.payload.model_dump(), - "devices": {dname: dev.dict() for dname, dev in daemon.devs.items()}, + "devices": {dn: dev.model_dump() for dn, dev in daemon.devs.items()}, "job": dict(id=job.id, path=str(root)), } diff --git a/src/tomato/driverinterface_1_0/__init__.py b/src/tomato/driverinterface_1_0/__init__.py index 53ee3869..b48db762 100644 --- a/src/tomato/driverinterface_1_0/__init__.py +++ b/src/tomato/driverinterface_1_0/__init__.py @@ -1,5 +1,5 @@ from abc import ABCMeta, abstractmethod -from typing import TypeVar, Any, Literal +from typing import TypeVar, Any from pydantic import BaseModel from threading import Thread, currentThread from queue import Queue @@ -8,7 +8,8 @@ import logging from functools import wraps from xarray import Dataset - +from collections import defaultdict +import time logger = logging.getLogger(__name__) @@ -26,20 +27,23 @@ def wrapper(self, **kwargs): return wrapper -class ModelInterface(metaclass=ABCMeta): - version: Literal = "1.0" +T = TypeVar("T") + + +class Attr(BaseModel): + """Class used to describe device attributes.""" + + type: T + rw: bool = False + status: bool = False - class Attr(BaseModel): - """Class used to describe device attributes.""" - type: TypeVar("T") - rw: bool = False - status: bool = False +class DriverInterface(metaclass=ABCMeta): + version: str = "1.0" class DeviceInterface(metaclass=ABCMeta): driver: object - data: list - status: dict + data: dict[str, list] key: tuple thread: Thread task_list: Queue @@ -49,9 +53,8 @@ def __init__(self, driver, key, **kwargs): self.driver = driver self.key = key self.task_list = Queue() - self.thread = Thread(target=self._worker_wrapper, daemon=True) - self.data = [] - self.status = {} + self.thread = Thread(target=self.task_runner, daemon=True) + self.data = defaultdict(list) self.running = False def run(self): @@ -59,20 +62,65 @@ def run(self): self.thread.start() self.running = True - def _worker_wrapper(self): + def task_runner(self): thread = currentThread() - task = self.task_list.get() - - self.task_runner(task, thread) + task: Task = self.task_list.get() + self.prepare_task(task) + t0 = time.perf_counter() + tD = t0 + self.data = defaultdict(list) + while getattr(thread, "do_run"): + tN = time.perf_counter() + if tN - tD > task.sampling_interval: + self.do_task(task, t0=t0, tN=tN, tD=tD) + tD += task.sampling_interval + if tN - t0 > task.max_duration: + break + time.sleep(max(1e-2, task.sampling_interval / 10)) self.task_list.task_done() self.running = False - self.thread = Thread(target=self._worker_wrapper, daemon=True) + self.thread = Thread(target=self.task_runner, daemon=True) + + def prepare_task(self, task: Task, **kwargs: dict): + for k, v in task.technique_params.items(): + self.set_attr(attr=k, val=v) + + @abstractmethod + def do_task(self, task: Task, **kwargs: dict): + pass + + def stop_task(self, **kwargs: dict): + setattr(self.thread, "do_run", False) + + @abstractmethod + def set_attr(self, attr: str, val: Any, **kwargs: dict): + pass @abstractmethod - def task_runner(task: Task, thread: Thread): + def get_attr(self, attr: str, **kwargs: dict) -> Any: pass + def get_data(self, **kwargs: dict) -> dict[str, list]: + ret = self.data + self.data = defaultdict(list) + return ret + + @abstractmethod + def attrs(**kwargs) -> dict: + pass + + @abstractmethod + def tasks(**kwargs) -> set: + pass + + def status(self, **kwargs) -> dict: + status = {} + for attr, props in self.attrs().items(): + if props.status: + status[attr] = self.get_attr(attr) + return status + def CreateDeviceInterface(self, key, **kwargs): """Factory function which passes DriverInterface to the DeviceInterface""" return self.DeviceInterface(self, key, **kwargs) @@ -80,7 +128,7 @@ def CreateDeviceInterface(self, key, **kwargs): devmap: dict[tuple, DeviceInterface] """Map of registered devices, the tuple keys are components = (address, channel)""" - settings: dict[str, str] + settings: dict[str, Any] """A settings map to contain driver-specific settings such as `dllpath` for BioLogic""" def __init__(self, settings=None): @@ -94,7 +142,7 @@ def dev_register(self, address: str, channel: int, **kwargs: dict) -> None: updating existing channels in :obj:`self.devmap`. """ key = (address, channel) - self.devmap[(address, channel)] = self.CreateDeviceInterface(key, **kwargs) + self.devmap[key] = self.CreateDeviceInterface(key, **kwargs) def dev_teardown(self, address: str, channel: int, **kwargs: dict) -> None: """ @@ -105,44 +153,42 @@ def dev_teardown(self, address: str, channel: int, **kwargs: dict) -> None: """ pass - @abstractmethod - def attrs(self, address: str, channel: int, **kwargs) -> dict[str, Attr]: - """ - Function that returns all gettable and settable attributes, their rw status, - and whether they are to be returned in :func:`self.dev_status`. All attrs are - returned by :func:`self.dev_get_data`. - - This is the "low level" control interface, intended for the device dashboard. - - Example: - :: - - return dict( - delay = self.Attr(type=float, rw=True, status=False), - time = self.Attr(type=float, rw=True, status=False), - started = self.Attr(type=bool, rw=True, status=True), - val = self.Attr(type=int, rw=False, status=True), - ) - - """ - pass + @in_devmap + def attrs(self, address: str, channel: int, **kwargs) -> Reply | None: + key = (address, channel) + ret = self.devmap[key].attrs(**kwargs) + return Reply( + success=True, + msg=f"attrs of component {key} are: {ret}", + data=ret, + ) @in_devmap - def dev_set_attr(self, attr: str, val: Any, address: str, channel: int, **kwargs): + def dev_set_attr( + self, attr: str, val: Any, address: str, channel: int, **kwargs + ) -> Reply | None: key = (address, channel) - if attr in self.attrs(): - params = self.attrs()[attr] - if params.rw and isinstance(val, params.type): - self.devmap[key].status[attr] = val + self.devmap[key].set_attr(attr=attr, val=val, **kwargs) + return Reply( + success=True, + msg=f"attr {attr!r} of component {key} set to {val}", + data=val, + ) @in_devmap - def dev_get_attr(self, attr: str, address: str, channel: int, **kwargs): + def dev_get_attr( + self, attr: str, address: str, channel: int, **kwargs + ) -> Reply | None: key = (address, channel) - if attr in self.attrs(address=address, channel=channel): - return self.devmap[key].status[attr] + ret = self.devmap[key].get_attr(attr=attr, **kwargs) + return Reply( + success=True, + msg=f"attr {attr!r} of component {key} is: {ret}", + data=ret, + ) @in_devmap - def dev_status(self, address: str, channel: int, **kwargs): + def dev_status(self, address: str, channel: int, **kwargs) -> Reply | None: key = (address, channel) running = self.devmap[key].running return Reply( @@ -152,12 +198,15 @@ def dev_status(self, address: str, channel: int, **kwargs): ) @in_devmap - def task_start(self, address: str, channel: int, task: Task, **kwargs): - if task.technique_name not in self.tasks(address=address, channel=channel): + def task_start( + self, address: str, channel: int, task: Task, **kwargs + ) -> Reply | None: + key = (address, channel) + if task.technique_name not in self.devmap[key].tasks(**kwargs): return Reply( success=False, msg=f"unknown task {task.technique_name!r} requested", - data=self.tasks(), + data=self.tasks(address=address, channel=channel), ) key = (address, channel) @@ -179,8 +228,11 @@ def task_status(self, address: str, channel: int): return Reply(success=True, msg="running") @in_devmap - def task_stop(self, address: str, channel: int): - self.dev_set_attr(attr="started", val=False, address=address, channel=channel) + def task_stop(self, address: str, channel: int, **kwargs) -> Reply | None: + key = (address, channel) + ret = self.devmap[key].stop_task(**kwargs) + if ret is not None: + return Reply(success=False, msg="failed to stop task", data=ret) ret = self.task_data(self, address, channel) if ret.success: @@ -189,27 +241,16 @@ def task_stop(self, address: str, channel: int): return Reply(success=True, msg=f"task stopped, {ret.msg}") @in_devmap - def task_data(self, address: str, channel: int, **kwargs): + def task_data(self, address: str, channel: int, **kwargs) -> Reply | None: key = (address, channel) - data = self.devmap[key].data - self.devmap[key].data = [] + data = self.devmap[key].get_data(**kwargs) if len(data) == 0: return Reply(success=False, msg="found no new datapoints") - data_vars = {} - for ii, item in enumerate(data): - for k, v in item.items(): - if k not in data_vars: - data_vars[k] = [None] * ii - data_vars[k].append(v) - for k in data_vars: - if k not in item: - data_vars[k].append(None) - - uts = {"uts": data_vars.pop("uts")} - data_vars = {k: ("uts", v) for k, v in data_vars.items()} - ds = Dataset(data_vars=data_vars, coords=uts) + uts = {"uts": data.pop("uts")} + data = {k: ("uts", v) for k, v in data.items()} + ds = Dataset(data_vars=data, coords=uts) return Reply(success=True, msg=f"found {len(data)} new datapoints", data=ds) def status(self): @@ -237,19 +278,11 @@ def dev_get_data(self, address: str, channel: int, **kwargs): ) return ret - @abstractmethod def tasks(self, address: str, channel: int, **kwargs) -> dict: - """ - Function that returns all tasks that can be submitted to the Device. This - implements the driver specific language. Each task in tasks can only contain - elements present in :func:`self.attrs`. - - Example: - :: - - return dict( - count = dict(time = dict(type=float), delay = dict(type=float), - ) - - """ - pass + key = (address, channel) + ret = self.devmap[key].tasks(**kwargs) + return Reply( + success=True, + msg=f"tasks supported by component {key} are: {ret}", + data=ret, + ) diff --git a/src/tomato/drivers.py b/src/tomato/drivers.py index 0d7d8401..33e8c32b 100644 --- a/src/tomato/drivers.py +++ b/src/tomato/drivers.py @@ -6,12 +6,12 @@ import logging from typing import Union -from tomato.driverinterface_1_0 import ModelInterface +from tomato.driverinterface_1_0 import DriverInterface logger = logging.getLogger(__name__) -def driver_to_interface(drivername: str) -> Union[None, ModelInterface]: +def driver_to_interface(drivername: str) -> Union[None, DriverInterface]: modname = f"tomato_{drivername.replace('-', '_')}" try: diff --git a/src/tomato/tomato/__init__.py b/src/tomato/tomato/__init__.py index 90d214f3..39bca3e6 100644 --- a/src/tomato/tomato/__init__.py +++ b/src/tomato/tomato/__init__.py @@ -395,7 +395,7 @@ def reload( logger.debug(f"{ret=}") if ret.success is False: return ret - params = dev.dict() + params = dev.model_dump() ret = _updater(context, port, "device", params) if ret.success is False: return ret @@ -408,7 +408,7 @@ def reload( for pip in pips.values(): logger.debug(f"{pip=}") if pip.name not in daemon.pips: - ret = _updater(context, port, "pipeline", pip.dict()) + ret = _updater(context, port, "pipeline", pip.model_dump()) if ret.success is False: return ret else: From 640bfdf7f9053b626957bef92d9d859c230a9c8f Mon Sep 17 00:00:00 2001 From: Peter Kraus Date: Mon, 15 Jul 2024 20:03:39 +0200 Subject: [PATCH 10/17] Fix 3.9 --- src/tomato/driverinterface_1_0/__init__.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/tomato/driverinterface_1_0/__init__.py b/src/tomato/driverinterface_1_0/__init__.py index b48db762..9b00961f 100644 --- a/src/tomato/driverinterface_1_0/__init__.py +++ b/src/tomato/driverinterface_1_0/__init__.py @@ -1,5 +1,5 @@ from abc import ABCMeta, abstractmethod -from typing import TypeVar, Any +from typing import TypeVar, Any, Union from pydantic import BaseModel from threading import Thread, currentThread from queue import Queue @@ -154,7 +154,7 @@ def dev_teardown(self, address: str, channel: int, **kwargs: dict) -> None: pass @in_devmap - def attrs(self, address: str, channel: int, **kwargs) -> Reply | None: + def attrs(self, address: str, channel: int, **kwargs) -> Union[Reply, None]: key = (address, channel) ret = self.devmap[key].attrs(**kwargs) return Reply( @@ -166,7 +166,7 @@ def attrs(self, address: str, channel: int, **kwargs) -> Reply | None: @in_devmap def dev_set_attr( self, attr: str, val: Any, address: str, channel: int, **kwargs - ) -> Reply | None: + ) -> Union[Reply, None]: key = (address, channel) self.devmap[key].set_attr(attr=attr, val=val, **kwargs) return Reply( @@ -178,7 +178,7 @@ def dev_set_attr( @in_devmap def dev_get_attr( self, attr: str, address: str, channel: int, **kwargs - ) -> Reply | None: + ) -> Union[Reply, None]: key = (address, channel) ret = self.devmap[key].get_attr(attr=attr, **kwargs) return Reply( @@ -188,7 +188,7 @@ def dev_get_attr( ) @in_devmap - def dev_status(self, address: str, channel: int, **kwargs) -> Reply | None: + def dev_status(self, address: str, channel: int, **kwargs) -> Union[Reply, None]: key = (address, channel) running = self.devmap[key].running return Reply( @@ -200,7 +200,7 @@ def dev_status(self, address: str, channel: int, **kwargs) -> Reply | None: @in_devmap def task_start( self, address: str, channel: int, task: Task, **kwargs - ) -> Reply | None: + ) -> Union[Reply, None]: key = (address, channel) if task.technique_name not in self.devmap[key].tasks(**kwargs): return Reply( @@ -228,7 +228,7 @@ def task_status(self, address: str, channel: int): return Reply(success=True, msg="running") @in_devmap - def task_stop(self, address: str, channel: int, **kwargs) -> Reply | None: + def task_stop(self, address: str, channel: int, **kwargs) -> Union[Reply, None]: key = (address, channel) ret = self.devmap[key].stop_task(**kwargs) if ret is not None: @@ -241,7 +241,7 @@ def task_stop(self, address: str, channel: int, **kwargs) -> Reply | None: return Reply(success=True, msg=f"task stopped, {ret.msg}") @in_devmap - def task_data(self, address: str, channel: int, **kwargs) -> Reply | None: + def task_data(self, address: str, channel: int, **kwargs) -> Union[Reply, None]: key = (address, channel) data = self.devmap[key].get_data(**kwargs) From c80a6c47ac3db5007e74e991c823dcec424f8c24 Mon Sep 17 00:00:00 2001 From: Peter Kraus Date: Mon, 15 Jul 2024 20:24:37 +0200 Subject: [PATCH 11/17] Disable autodoc_pydantic for now --- docs/source/conf.py | 2 +- src/tomato/drivers.py | 6 +++++- src/tomato/models.py | 6 ++++++ 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index bd706d06..b73ed9e0 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -39,7 +39,7 @@ # "sphinx.ext.autosummary", # "sphinx_autodoc_typehints", "sphinx_rtd_theme", - "sphinxcontrib.autodoc_pydantic", + # "sphinxcontrib.autodoc_pydantic", "sphinxcontrib.mermaid", ] diff --git a/src/tomato/drivers.py b/src/tomato/drivers.py index 33e8c32b..fa2cc208 100644 --- a/src/tomato/drivers.py +++ b/src/tomato/drivers.py @@ -1,5 +1,9 @@ """ -Driver documentation goes here. +**tomato.drivers**: Shim interfacing with tomato driver packages +---------------------------------------------------------------- +.. codeauthor:: + Peter Kraus + """ import importlib diff --git a/src/tomato/models.py b/src/tomato/models.py index e27d99cb..49fe5401 100644 --- a/src/tomato/models.py +++ b/src/tomato/models.py @@ -1,3 +1,9 @@ +""" +**tomato.models**: Pydantic models for internal tomato use +---------------------------------------------------------- +.. codeauthor:: + Peter Kraus +""" from pydantic import BaseModel, Field from typing import Optional, Any, Mapping, Sequence, Literal from pathlib import Path From d6febd430cbb9bb1af55070263ed83f8c4fa0aee Mon Sep 17 00:00:00 2001 From: Peter Kraus Date: Mon, 15 Jul 2024 20:36:10 +0200 Subject: [PATCH 12/17] Update to newer model. --- docs/source/conf.py | 2 +- docs/source/quickstart.rst | 2 +- src/tomato/models.py | 1 + 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index b73ed9e0..bd706d06 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -39,7 +39,7 @@ # "sphinx.ext.autosummary", # "sphinx_autodoc_typehints", "sphinx_rtd_theme", - # "sphinxcontrib.autodoc_pydantic", + "sphinxcontrib.autodoc_pydantic", "sphinxcontrib.mermaid", ] diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index 0d4ce6c7..8f9f2655 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -325,4 +325,4 @@ The *payload* file contains all information required to enter a *job* into the q allow its assignment onto a *pipeline*. The overall schema of the *payload* is defined in the :mod:`dgbowl_schemas.tomato` module, and is parsed using :func:`dgbowl_schemas.tomato.to_payload`: -.. autopydantic_model:: dgbowl_schemas.tomato.payload_0_2.Payload \ No newline at end of file +.. autopydantic_model:: dgbowl_schemas.tomato.payload_1_0.Payload \ No newline at end of file diff --git a/src/tomato/models.py b/src/tomato/models.py index 49fe5401..61d1d43e 100644 --- a/src/tomato/models.py +++ b/src/tomato/models.py @@ -4,6 +4,7 @@ .. codeauthor:: Peter Kraus """ + from pydantic import BaseModel, Field from typing import Optional, Any, Mapping, Sequence, Literal from pathlib import Path From a0e5ba9ec614b0d68d77492e610ff4dec52f77a8 Mon Sep 17 00:00:00 2001 From: Peter Kraus Date: Tue, 16 Jul 2024 15:07:47 +0200 Subject: [PATCH 13/17] Changes to driverinterface --- pyproject.toml | 8 +++- src/tomato/driverinterface_1_0/__init__.py | 28 +++++++++----- tests/common/devices_psutil.json | 29 ++++++++++++++ tests/common/psutil_1_0.1.yml | 10 +++++ tests/common/psutil_counter.yml | 21 ++++++++++ tests/conftest.py | 2 +- tests/test_99_psutil.py | 45 ++++++++++++++++++++++ 7 files changed, 131 insertions(+), 12 deletions(-) create mode 100644 tests/common/devices_psutil.json create mode 100644 tests/common/psutil_1_0.1.yml create mode 100644 tests/common/psutil_counter.yml create mode 100644 tests/test_99_psutil.py diff --git a/pyproject.toml b/pyproject.toml index 54c09684..c3c3140d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,11 +33,15 @@ dependencies = [ "h5netcdf >= 1.3", "xarray >= 2024.2", "pydantic >= 2.0", - "tomato-example-counter @ git+https://github.com/dgbowl/tomato-example-counter.git", + ] [project.optional-dependencies] -testing = ["pytest"] +testing = [ + "pytest", + "tomato-example-counter @ git+https://github.com/dgbowl/tomato-example-counter.git", + "tomato-psutil @ git+https://github.com/dgbowl/tomato-psutil.git", +] docs = [ "sphinx ~= 7.2", "sphinx-rtd-theme ~= 1.3.0", diff --git a/src/tomato/driverinterface_1_0/__init__.py b/src/tomato/driverinterface_1_0/__init__.py index 9b00961f..fc9cdc88 100644 --- a/src/tomato/driverinterface_1_0/__init__.py +++ b/src/tomato/driverinterface_1_0/__init__.py @@ -1,7 +1,7 @@ from abc import ABCMeta, abstractmethod from typing import TypeVar, Any, Union from pydantic import BaseModel -from threading import Thread, currentThread +from threading import Thread, currentThread, RLock from queue import Queue from tomato.models import Reply from dgbowl_schemas.tomato.payload import Task @@ -36,6 +36,7 @@ class Attr(BaseModel): type: T rw: bool = False status: bool = False + units: str = None class DriverInterface(metaclass=ABCMeta): @@ -44,6 +45,7 @@ class DriverInterface(metaclass=ABCMeta): class DeviceInterface(metaclass=ABCMeta): driver: object data: dict[str, list] + datalock: RLock key: tuple thread: Thread task_list: Queue @@ -56,6 +58,7 @@ def __init__(self, driver, key, **kwargs): self.thread = Thread(target=self.task_runner, daemon=True) self.data = defaultdict(list) self.running = False + self.datalock = RLock() def run(self): self.thread.do_run = True @@ -72,7 +75,8 @@ def task_runner(self): while getattr(thread, "do_run"): tN = time.perf_counter() if tN - tD > task.sampling_interval: - self.do_task(task, t0=t0, tN=tN, tD=tD) + with self.datalock: + self.do_task(task, t0=t0, tN=tN, tD=tD) tD += task.sampling_interval if tN - t0 > task.max_duration: break @@ -83,8 +87,9 @@ def task_runner(self): self.thread = Thread(target=self.task_runner, daemon=True) def prepare_task(self, task: Task, **kwargs: dict): - for k, v in task.technique_params.items(): - self.set_attr(attr=k, val=v) + if task.technique_params is not None: + for k, v in task.technique_params.items(): + self.set_attr(attr=k, val=v) @abstractmethod def do_task(self, task: Task, **kwargs: dict): @@ -102,12 +107,13 @@ def get_attr(self, attr: str, **kwargs: dict) -> Any: pass def get_data(self, **kwargs: dict) -> dict[str, list]: - ret = self.data - self.data = defaultdict(list) + with self.datalock: + ret = self.data + self.data = defaultdict(list) return ret @abstractmethod - def attrs(**kwargs) -> dict: + def attrs(**kwargs) -> dict[str, Attr]: pass @abstractmethod @@ -248,9 +254,13 @@ def task_data(self, address: str, channel: int, **kwargs) -> Union[Reply, None]: if len(data) == 0: return Reply(success=False, msg="found no new datapoints") + attrs = self.devmap[key].attrs(**kwargs) uts = {"uts": data.pop("uts")} - data = {k: ("uts", v) for k, v in data.items()} - ds = Dataset(data_vars=data, coords=uts) + data_vars = {} + for k, v in data.items(): + units = {} if attrs[k].units is None else {"units": attrs[k].units} + data_vars[k] = ("uts", v, units) + ds = Dataset(data_vars=data_vars, coords=uts) return Reply(success=True, msg=f"found {len(data)} new datapoints", data=ds) def status(self): diff --git a/tests/common/devices_psutil.json b/tests/common/devices_psutil.json new file mode 100644 index 00000000..615b0e3d --- /dev/null +++ b/tests/common/devices_psutil.json @@ -0,0 +1,29 @@ +{ + "devices": [ + { + "name": "dev-counter", + "driver": "example_counter", + "address": "counter-addr", + "channels": [1], + "capabilities": ["count", "random"], + "pollrate": 1 + }, + { + "name": "dev-psutil", + "driver": "psutil", + "address": "psutil-addr", + "channels": [10], + "capabilities": ["all_info", "cpu_info", "mem_info"], + "pollrate": 1 + } + ], + "pipelines": [ + { + "name": "pip-multidev", + "devices": [ + {"tag": "counter", "name": "dev-counter", "channel": 1}, + {"tag": "psutil", "name": "dev-psutil", "channel": 10} + ] + } + ] +} \ No newline at end of file diff --git a/tests/common/psutil_1_0.1.yml b/tests/common/psutil_1_0.1.yml new file mode 100644 index 00000000..ba2518e4 --- /dev/null +++ b/tests/common/psutil_1_0.1.yml @@ -0,0 +1,10 @@ +version: "1.0" +sample: + name: psutil_1_0.1 +method: + - component_tag: "psutil" + technique_name: "all_info" + max_duration: 1.0 + sampling_interval: 0.1 +settings: + verbosity: "DEBUG" \ No newline at end of file diff --git a/tests/common/psutil_counter.yml b/tests/common/psutil_counter.yml new file mode 100644 index 00000000..b6a4affc --- /dev/null +++ b/tests/common/psutil_counter.yml @@ -0,0 +1,21 @@ +version: "1.0" +sample: + name: psutil_counter +method: + - component_tag: "psutil" + technique_name: "all_info" + max_duration: 1.0 + sampling_interval: 0.1 + - component_tag: "psutil" + technique_name: "all_info" + max_duration: 1.0 + sampling_interval: 0.5 + - component_tag: "counter" + technique_name: "random" + max_duration: 2.0 + sampling_interval: 0.2 + technique_params: + min: 50.0 + max: 100.0 +settings: + verbosity: "DEBUG" \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index c6c6f7fe..708d787b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -30,7 +30,7 @@ def start_tomato_daemon(tmpdir: str, port: int = 12345): # setup_stuff os.chdir(tmpdir) subprocess.run(["tomato", "init", "-p", f"{port}", "-A", ".", "-D", "."]) - subprocess.run(["tomato", "start", "-p", f"{port}", "-A", ".", "-L", ".", "-VV"]) + subprocess.run(["tomato", "start", "-p", f"{port}", "-A", ".", "-L", ".", "-vv"]) yield # teardown_stuff diff --git a/tests/test_99_psutil.py b/tests/test_99_psutil.py new file mode 100644 index 00000000..446b5a8d --- /dev/null +++ b/tests/test_99_psutil.py @@ -0,0 +1,45 @@ +import pytest +import os +import subprocess +import json +import yaml +import xarray as xr + +from . import utils + +PORT = 12345 + + +@pytest.mark.parametrize( + "casename, npoints", + [ + ("psutil_1_0.1", {"psutil": 10}), + ("psutil_counter", {"psutil": 12, "counter": 10}), + ], +) +def test_psutil_multidev(casename, npoints, datadir, stop_tomato_daemon): + os.chdir(datadir) + with open("devices_psutil.json", "r") as inf: + jsdata = json.load(inf) + with open("devices.yml", "w") as ouf: + yaml.dump(jsdata, ouf) + subprocess.run(["tomato", "init", "-p", f"{PORT}", "-A", ".", "-D", "."]) + subprocess.run(["tomato", "start", "-p", f"{PORT}", "-A", ".", "-L", ".", "-vv"]) + utils.wait_until_tomato_running(port=PORT, timeout=3000) + + utils.run_casenames([casename], [None], ["pip-multidev"]) + utils.wait_until_ketchup_status(jobid=1, status="r", port=PORT, timeout=2000) + utils.wait_until_ketchup_status(jobid=1, status="c", port=PORT, timeout=2000) + + ret = utils.job_status(1) + print(f"{ret=}") + status = utils.job_status(1)["data"][1]["status"] + assert status == "c" + files = os.listdir(os.path.join(".", "Jobs", "1")) + assert "jobdata.json" in files + assert "job-1.log" in files + assert os.path.exists("results.1.nc") + for group, points in npoints.items(): + ds = xr.load_dataset("results.1.nc", group=group) + print(f"{ds=}") + assert ds["uts"].size == points From e757a10a3f69d824a4e18b2d7f506e4f971a4f55 Mon Sep 17 00:00:00 2001 From: Peter Kraus Date: Tue, 16 Jul 2024 15:24:00 +0200 Subject: [PATCH 14/17] Consistency ModelInterface --- src/tomato/driverinterface_1_0/__init__.py | 22 +++++++++++----------- src/tomato/drivers.py | 4 ++-- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/tomato/driverinterface_1_0/__init__.py b/src/tomato/driverinterface_1_0/__init__.py index fc9cdc88..070e0f65 100644 --- a/src/tomato/driverinterface_1_0/__init__.py +++ b/src/tomato/driverinterface_1_0/__init__.py @@ -39,10 +39,10 @@ class Attr(BaseModel): units: str = None -class DriverInterface(metaclass=ABCMeta): +class ModelInterface(metaclass=ABCMeta): version: str = "1.0" - class DeviceInterface(metaclass=ABCMeta): + class DeviceManager(metaclass=ABCMeta): driver: object data: dict[str, list] datalock: RLock @@ -117,7 +117,7 @@ def attrs(**kwargs) -> dict[str, Attr]: pass @abstractmethod - def tasks(**kwargs) -> set: + def capabilities(**kwargs) -> set: pass def status(self, **kwargs) -> dict: @@ -127,11 +127,11 @@ def status(self, **kwargs) -> dict: status[attr] = self.get_attr(attr) return status - def CreateDeviceInterface(self, key, **kwargs): + def CreateDeviceManager(self, key, **kwargs): """Factory function which passes DriverInterface to the DeviceInterface""" - return self.DeviceInterface(self, key, **kwargs) + return self.DeviceManager(self, key, **kwargs) - devmap: dict[tuple, DeviceInterface] + devmap: dict[tuple, DeviceManager] """Map of registered devices, the tuple keys are components = (address, channel)""" settings: dict[str, Any] @@ -144,11 +144,11 @@ def __init__(self, settings=None): def dev_register(self, address: str, channel: int, **kwargs: dict) -> None: """ Register a Device and its Component in this DriverInterface, creating a - :obj:`self.DeviceInterface` object in the :obj:`self.devmap` if necessary, or + :obj:`self.DeviceManager` object in the :obj:`self.devmap` if necessary, or updating existing channels in :obj:`self.devmap`. """ key = (address, channel) - self.devmap[key] = self.CreateDeviceInterface(key, **kwargs) + self.devmap[key] = self.CreateDeviceManager(key, **kwargs) def dev_teardown(self, address: str, channel: int, **kwargs: dict) -> None: """ @@ -208,7 +208,7 @@ def task_start( self, address: str, channel: int, task: Task, **kwargs ) -> Union[Reply, None]: key = (address, channel) - if task.technique_name not in self.devmap[key].tasks(**kwargs): + if task.technique_name not in self.devmap[key].capabilities(**kwargs): return Reply( success=False, msg=f"unknown task {task.technique_name!r} requested", @@ -288,11 +288,11 @@ def dev_get_data(self, address: str, channel: int, **kwargs): ) return ret - def tasks(self, address: str, channel: int, **kwargs) -> dict: + def capabilities(self, address: str, channel: int, **kwargs) -> dict: key = (address, channel) ret = self.devmap[key].tasks(**kwargs) return Reply( success=True, - msg=f"tasks supported by component {key} are: {ret}", + msg=f"capabilities supported by component {key!r} are: {ret}", data=ret, ) diff --git a/src/tomato/drivers.py b/src/tomato/drivers.py index fa2cc208..5d279121 100644 --- a/src/tomato/drivers.py +++ b/src/tomato/drivers.py @@ -10,12 +10,12 @@ import logging from typing import Union -from tomato.driverinterface_1_0 import DriverInterface +from tomato.driverinterface_1_0 import ModelInterface logger = logging.getLogger(__name__) -def driver_to_interface(drivername: str) -> Union[None, DriverInterface]: +def driver_to_interface(drivername: str) -> Union[None, ModelInterface]: modname = f"tomato_{drivername.replace('-', '_')}" try: From adf963c994a8f309066e0180b97ff8eac3cfd206 Mon Sep 17 00:00:00 2001 From: Peter Kraus Date: Tue, 16 Jul 2024 18:49:57 +0200 Subject: [PATCH 15/17] Docs --- src/tomato/daemon/job.py | 40 ++-- src/tomato/driverinterface_1_0/__init__.py | 260 +++++++++++++++------ 2 files changed, 214 insertions(+), 86 deletions(-) diff --git a/src/tomato/daemon/job.py b/src/tomato/daemon/job.py index ee774ad9..239244fb 100644 --- a/src/tomato/daemon/job.py +++ b/src/tomato/daemon/job.py @@ -21,7 +21,7 @@ from importlib import metadata from datetime import datetime, timezone from pathlib import Path -from threading import currentThread, Thread +from threading import current_thread, Thread from typing import Any import zmq import psutil @@ -227,7 +227,7 @@ def manager(port: int, timeout: int = 500): """ context = zmq.Context() logger = logging.getLogger(f"{__name__}.manager") - thread = currentThread() + thread = current_thread() logger.info("launched successfully") req = context.socket(zmq.REQ) req.connect(f"tcp://127.0.0.1:{port}") @@ -423,12 +423,7 @@ def job_thread( Stores the data for that Component as a `pickle` of a :class:`xr.Dataset`. """ - sender = f"{__name__}.job_thread" - # logging.basicConfig( - # level=logging.DEBUG, - # format="%(asctime)s - %(levelname)8s - %(name)-30s - %(message)s", - # handlers=[logging.FileHandler(logpath, mode="a")], - # ) + sender = f"{__name__}.job_thread({current_thread().ident})" logger = logging.getLogger(sender) logger.debug(f"in job thread of {component.role!r}") @@ -444,36 +439,43 @@ def job_thread( for task in tasks: logger.debug(f"{task=}") while True: + logger.debug("polling component '%s' for task readiness", component.role) req.send_pyobj(dict(cmd="task_status", params={**kwargs})) ret = req.recv_pyobj() - logger.debug(f"{ret=}") - if ret.success and ret.msg == "ready": + if ret.success and ret.data["can_submit"]: break - + logger.warning("cannot submit onto component '%s', waiting", component.role) + time.sleep(1e-1) + logger.debug("sending task to component '%s'", component.role) req.send_pyobj(dict(cmd="task_start", params={"task": task, **kwargs})) ret = req.recv_pyobj() - logger.debug(f"{ret=}") t0 = time.perf_counter() while True: tN = time.perf_counter() if tN - t0 > device.pollrate: + logger.debug("polling component '%s' for data", component.role) req.send_pyobj(dict(cmd="task_data", params={**kwargs})) ret = req.recv_pyobj() if ret.success: + logger.debug("pickling received data") data_to_pickle(ret.data, datapath, role=component.role) t0 += device.pollrate + + logger.debug("polling component '%s' for task completion", component.role) req.send_pyobj(dict(cmd="task_status", params={**kwargs})) ret = req.recv_pyobj() - logger.debug(f"{ret=}") - if ret.success and ret.msg == "ready": + if ret.success and not ret.data["running"]: + logger.debug("task no longer running, break") break - time.sleep(device.pollrate - (tN - t0)) - logger.debug("tock") + time.sleep(max(1e-1, (device.pollrate - (tN - t0)) / 2)) + + logger.debug("fetching final data for task") req.send_pyobj(dict(cmd="task_data", params={**kwargs})) ret = req.recv_pyobj() if ret.success: data_to_pickle(ret.data, datapath, role=component.role) + logger.debug("all tasks done on component '%s'", component.role) def job_main_loop( @@ -519,11 +521,11 @@ def job_main_loop( threads = {} for role, tasks in plan.items(): component = pipeline.devs[role] - logger.debug(f"{component=}") + logger.debug(" component=%s", component) device = daemon.devs[component.name] - logger.debug(f"{device=}") + logger.debug(" device=%s", device) driver = daemon.drvs[device.driver] - logger.debug(f"{driver=}") + logger.debug(" driver=%s", driver) threads[role] = Thread( target=job_thread, args=(tasks, component, device, driver, jobpath, logpath), diff --git a/src/tomato/driverinterface_1_0/__init__.py b/src/tomato/driverinterface_1_0/__init__.py index 070e0f65..3bbae81b 100644 --- a/src/tomato/driverinterface_1_0/__init__.py +++ b/src/tomato/driverinterface_1_0/__init__.py @@ -1,7 +1,7 @@ from abc import ABCMeta, abstractmethod from typing import TypeVar, Any, Union from pydantic import BaseModel -from threading import Thread, currentThread, RLock +from threading import Thread, current_thread, RLock from queue import Queue from tomato.models import Reply from dgbowl_schemas.tomato.payload import Task @@ -17,12 +17,16 @@ def in_devmap(func): @wraps(func) def wrapper(self, **kwargs): - address = kwargs.get("address") - channel = kwargs.get("channel") - if (address, channel) not in self.devmap: + if "key" in kwargs: + key = kwargs.pop("key") + else: + address = kwargs.get("address") + channel = kwargs.get("channel") + key = (address, channel) + if key not in self.devmap: msg = f"dev with address {address!r} and channel {channel} is unknown" return Reply(success=False, msg=msg, data=self.devmap.keys()) - return func(self, **kwargs) + return func(self, **kwargs, key=key) return wrapper @@ -31,7 +35,7 @@ def wrapper(self, **kwargs): class Attr(BaseModel): - """Class used to describe device attributes.""" + """A Pydantic :class:`BaseModel` used to describe device attributes.""" type: T rw: bool = False @@ -40,15 +44,39 @@ class Attr(BaseModel): class ModelInterface(metaclass=ABCMeta): + """ + An abstract base class specifying the a driver interface. + + Individual driver modules should expose a :class:`DriverInterface` which inherits + from this abstract class. Only the methods of this class should be used to interact + with drivers and their devices. + """ + version: str = "1.0" class DeviceManager(metaclass=ABCMeta): - driver: object + """ + An abstract base class specifying a manager for an individual component. + """ + + driver: super + """The parent :class:`DriverInterface` instance.""" + data: dict[str, list] + """Container for cached data on this component.""" + datalock: RLock + """Lock object for thread-safe data manipulation.""" + key: tuple + """The key in :obj:`driver.devmap` referring to this object.""" + thread: Thread + """The worker :class:`Thread`.""" + task_list: Queue + """A :class:`Queue` used to pass :class:`Tasks` to the worker :class:`Thread`.""" + running: bool def __init__(self, driver, key, **kwargs): @@ -61,52 +89,80 @@ def __init__(self, driver, key, **kwargs): self.datalock = RLock() def run(self): + """Helper function for starting the :obj:`self.thread`.""" self.thread.do_run = True self.thread.start() self.running = True def task_runner(self): - thread = currentThread() + """ + Target function for the :obj:`self.thread`. + + This function waits for a :class:`Task` passed using :obj:`self.task_list`, + then handles setting all :class:`Attrs` using the :func:`prepare_task` + function, and finally handles the main loop of the task, periodically running + the :func:`do_task` function (using `task.sampling_interval`) until the + maximum task duration (i.e. `task.max_duration`) is exceeded. + + The :obj:`self.thread` is re-primed for future :class:`Tasks` at the end + of this function. + """ + thread = current_thread() task: Task = self.task_list.get() self.prepare_task(task) - t0 = time.perf_counter() - tD = t0 + t_start = time.perf_counter() + t_prev = t_start self.data = defaultdict(list) while getattr(thread, "do_run"): - tN = time.perf_counter() - if tN - tD > task.sampling_interval: + t_now = time.perf_counter() + if t_now - t_prev > task.sampling_interval: with self.datalock: - self.do_task(task, t0=t0, tN=tN, tD=tD) - tD += task.sampling_interval - if tN - t0 > task.max_duration: + self.do_task(task, t_start=t_start, t_now=t_now, t_prev=t_prev) + t_prev += task.sampling_interval + if t_now - t_start > task.max_duration: break - time.sleep(max(1e-2, task.sampling_interval / 10)) + time.sleep(max(1e-2, task.sampling_interval / 20)) self.task_list.task_done() self.running = False self.thread = Thread(target=self.task_runner, daemon=True) def prepare_task(self, task: Task, **kwargs: dict): + """ + Given a :class:`Task`, prepare this component for execution by settin all + :class:`Attrs` as specified in the `task.technique_params` dictionary. + """ if task.technique_params is not None: for k, v in task.technique_params.items(): self.set_attr(attr=k, val=v) @abstractmethod def do_task(self, task: Task, **kwargs: dict): + """ + Periodically called task execution function. + + This function is responsible for updating :obj:`self.data` with new data, i.e. + performing the measurement. It should also update the values of all + :class:`Attrs`, so that the component status is consistent with the cached data. + """ pass def stop_task(self, **kwargs: dict): + """Stops the currently running task.""" setattr(self.thread, "do_run", False) @abstractmethod def set_attr(self, attr: str, val: Any, **kwargs: dict): + """Sets the specified :class:`Attr` to `val`.""" pass @abstractmethod def get_attr(self, attr: str, **kwargs: dict) -> Any: + """Reads the value of the specified :class:`Attr`.""" pass def get_data(self, **kwargs: dict) -> dict[str, list]: + """Returns the cached :obj:`self.data` before clearing the cache.""" with self.datalock: ret = self.data self.data = defaultdict(list) @@ -114,13 +170,16 @@ def get_data(self, **kwargs: dict) -> dict[str, list]: @abstractmethod def attrs(**kwargs) -> dict[str, Attr]: + """Returns a :class:`dict` of all available :class:`Attrs`.""" pass @abstractmethod def capabilities(**kwargs) -> set: + """Returns a :class:`set` of all supported techniques.""" pass def status(self, **kwargs) -> dict: + """Compiles a status report from :class:`Attrs` marked as `status=True`.""" status = {} for attr, props in self.attrs().items(): if props.status: @@ -128,11 +187,14 @@ def status(self, **kwargs) -> dict: return status def CreateDeviceManager(self, key, **kwargs): - """Factory function which passes DriverInterface to the DeviceInterface""" + """ + A factory function which is used to pass this :class:`ModelInterface` to the new + :class:`DeviceManager` instance. + """ return self.DeviceManager(self, key, **kwargs) devmap: dict[tuple, DeviceManager] - """Map of registered devices, the tuple keys are components = (address, channel)""" + """Map of registered devices, the tuple keys are `component = (address, channel)`""" settings: dict[str, Any] """A settings map to contain driver-specific settings such as `dllpath` for BioLogic""" @@ -143,25 +205,37 @@ def __init__(self, settings=None): def dev_register(self, address: str, channel: int, **kwargs: dict) -> None: """ - Register a Device and its Component in this DriverInterface, creating a - :obj:`self.DeviceManager` object in the :obj:`self.devmap` if necessary, or - updating existing channels in :obj:`self.devmap`. + Register a new device component in this driver. + + Creates a :class:`DeviceManager` representing a device component, storing it in + the :obj:`self.devmap` using the provided `address` and `channel`. """ key = (address, channel) self.devmap[key] = self.CreateDeviceManager(key, **kwargs) - def dev_teardown(self, address: str, channel: int, **kwargs: dict) -> None: + @in_devmap + def dev_teardown(self, key: tuple, **kwargs: dict) -> None: """ - Emergency stop function. Set the device into a documented, safe state. + Emergency stop function. - The function is to be only called in case of critical errors, not as part of - normal operation. + Should set the device component into a documented, safe state. The function is + to be only called in case of critical errors, or when the component is being + removed, not as part of normal operation (i.e. it is not intended as a clean-up + after task completion). """ - pass + status = self.task_status(key=key, **kwargs) + if status.data: + logger.warning("tearing down component '%s' with a running task!") + self.task_stop(key=key, **kwargs) + del self.devmap[key] @in_devmap - def attrs(self, address: str, channel: int, **kwargs) -> Union[Reply, None]: - key = (address, channel) + def attrs(self, key: tuple, **kwargs: dict) -> Union[Reply, None]: + """ + Query available :class:`Attrs` on the specified device component. + + Pass-through to the :func:`DeviceManager.attrs` function. + """ ret = self.devmap[key].attrs(**kwargs) return Reply( success=True, @@ -171,9 +245,14 @@ def attrs(self, address: str, channel: int, **kwargs) -> Union[Reply, None]: @in_devmap def dev_set_attr( - self, attr: str, val: Any, address: str, channel: int, **kwargs + self, attr: str, val: Any, key: tuple, **kwargs: dict ) -> Union[Reply, None]: - key = (address, channel) + """ + Set value of the :class:`Attr` of the specified device component. + + Pass-through to the :func:`DeviceManager.set_attr` function. No type or + read-write validation performed here! + """ self.devmap[key].set_attr(attr=attr, val=val, **kwargs) return Reply( success=True, @@ -182,10 +261,14 @@ def dev_set_attr( ) @in_devmap - def dev_get_attr( - self, attr: str, address: str, channel: int, **kwargs - ) -> Union[Reply, None]: - key = (address, channel) + def dev_get_attr(self, attr: str, key: tuple, **kwargs: dict) -> Union[Reply, None]: + """ + Get value of the :class:`Attr` from the specified device component. + + Pass-through to the :func:`DeviceManager.get_attr` function. Units are not + returned; those can be queried for all :class:`Attrs` using :func:`attrs`. + + """ ret = self.devmap[key].get_attr(attr=attr, **kwargs) return Reply( success=True, @@ -194,28 +277,41 @@ def dev_get_attr( ) @in_devmap - def dev_status(self, address: str, channel: int, **kwargs) -> Union[Reply, None]: - key = (address, channel) - running = self.devmap[key].running + def dev_status(self, key: tuple, **kwargs: dict) -> Union[Reply, None]: + """ + Get the status report from the specified device component. + + Iterates over all :class:`Attrs` on the component that have `status=True` and + returns their values in a :class:`dict`. + """ + ret = {} + for k, attr in self.devmap[key].attrs(key=key, **kwargs).items(): + if attr.status: + ret[k] = self.devmap[key].get_attr(attr=k, **kwargs) + + ret["running"] = self.devmap[key].running return Reply( success=True, - msg=f"component {key} is{' ' if running else ' not ' }running", - data=running, + msg=f"component {key} is{' ' if ret['running'] else ' not ' }running", + data=ret, ) @in_devmap - def task_start( - self, address: str, channel: int, task: Task, **kwargs - ) -> Union[Reply, None]: - key = (address, channel) + def task_start(self, key: tuple, task: Task, **kwargs) -> Union[Reply, None]: + """ + Submit a :class:`Task` onto the specified device component. + + Pushes the supplied :class:`Task` into the :class:`Queue` of the component, + then starts the worker thread. Checks that the :class:`Task` is among the + capabilities of this component. + """ if task.technique_name not in self.devmap[key].capabilities(**kwargs): return Reply( success=False, msg=f"unknown task {task.technique_name!r} requested", - data=self.tasks(address=address, channel=channel), + data=self.capabilities(key=key), ) - key = (address, channel) self.devmap[key].task_list.put(task) self.devmap[key].run() return Reply( @@ -225,30 +321,51 @@ def task_start( ) @in_devmap - def task_status(self, address: str, channel: int): - key = (address, channel) - started = self.devmap[key].running - if not started: - return Reply(success=True, msg="ready") + def task_status(self, key: tuple, **kwargs: dict) -> Reply: + """ + Returns the task readiness status of the specified device component. + + The `running` entry in the data slot of the :class:`Reply` indicates whether + a :class:`Task` is running. The `can_submit` entry indicates whether another + :class:`Task` can be queued onto the device component already. + """ + running = self.devmap[key].running + data = dict(running=running, can_submit=not running) + if running: + return Reply(success=True, msg="running", data=data) else: - return Reply(success=True, msg="running") + return Reply(success=True, msg="ready", data=data) @in_devmap - def task_stop(self, address: str, channel: int, **kwargs) -> Union[Reply, None]: - key = (address, channel) + def task_stop(self, key: tuple, **kwargs) -> Union[Reply, None]: + """ + Stops a running task and returns any collected data. + + Pass-through to :func:`DriverManager.stop_task` and :func:`task_data`. + """ ret = self.devmap[key].stop_task(**kwargs) if ret is not None: return Reply(success=False, msg="failed to stop task", data=ret) - ret = self.task_data(self, address, channel) + ret = self.task_data(self, key=key) if ret.success: return Reply(success=True, msg=f"task stopped, {ret.msg}", data=ret.data) else: return Reply(success=True, msg=f"task stopped, {ret.msg}") @in_devmap - def task_data(self, address: str, channel: int, **kwargs) -> Union[Reply, None]: - key = (address, channel) + def task_data(self, key: tuple, **kwargs) -> Union[Reply, None]: + """ + Return cached task data on the device component and clean the cache. + + Pass-through for :func:`DeviceManager.get_data`, with the caveat that the + :class:`dict[list]` which is returned from the component is here converted to a + :class:`Dataset` and annotated using units from :func:`attrs`. + + This function gets called by the job thread every `device.pollrate`, it therefore + incurs some IPC cost. + + """ data = self.devmap[key].get_data(**kwargs) if len(data) == 0: @@ -264,6 +381,11 @@ def task_data(self, address: str, channel: int, **kwargs) -> Union[Reply, None]: return Reply(success=True, msg=f"found {len(data)} new datapoints", data=ds) def status(self): + """ + Returns the driver status. Currently that is the names of the components in + the `devmap`. + + """ devkeys = self.devmap.keys() return Reply( success=True, @@ -272,25 +394,29 @@ def status(self): ) def teardown(self): + """ + Tears down the driver. + + Called when the driver process is quitting. Instructs all remaining tasks to + stop. Warns when devices linger. This is not a pass-through to :func:`dev_teardown`. + + """ for key, dev in self.devmap.items(): - dev.thread.do_run = False + setattr(dev.thread, "do_run", False) dev.thread.join(1) if dev.thread.is_alive(): logger.error(f"device {key!r} is still alive") else: logger.debug(f"device {key!r} successfully closed") - def dev_get_data(self, address: str, channel: int, **kwargs): - ret = {} - for k in self.attrs(address=address, channel=channel, **kwargs).keys(): - ret[k] = self.dev_get_attr( - attr=k, address=address, channel=channel, **kwargs - ) - return ret + @in_devmap + def capabilities(self, key: tuple, **kwargs) -> dict: + """ + Returns the capabilities of the device component. - def capabilities(self, address: str, channel: int, **kwargs) -> dict: - key = (address, channel) - ret = self.devmap[key].tasks(**kwargs) + Pass-through to :func:`DriverManager.capabilities`. + """ + ret = self.devmap[key].capabilities(**kwargs) return Reply( success=True, msg=f"capabilities supported by component {key!r} are: {ret}", From 91e70acbf212653c1182b0607eb27809721d0620 Mon Sep 17 00:00:00 2001 From: Peter Kraus Date: Wed, 17 Jul 2024 09:11:45 +0200 Subject: [PATCH 16/17] Final changes. --- src/tomato/daemon/driver.py | 25 +++---- src/tomato/daemon/job.py | 6 +- src/tomato/driverinterface_1_0/__init__.py | 83 ++++++++++++++++------ 3 files changed, 74 insertions(+), 40 deletions(-) diff --git a/src/tomato/daemon/driver.py b/src/tomato/daemon/driver.py index f03df824..fc58fc28 100644 --- a/src/tomato/daemon/driver.py +++ b/src/tomato/daemon/driver.py @@ -18,6 +18,7 @@ import zmq import psutil +from tomato.driverinterface_1_0 import ModelInterface from tomato.drivers import driver_to_interface from tomato.models import Reply @@ -100,7 +101,7 @@ def tomato_driver() -> None: if Interface is None: logger.critical(f"library of driver {args.driver!r} not found") return - interface = Interface(settings=daemon.drvs[args.driver].settings) + interface: ModelInterface = Interface(settings=daemon.drvs[args.driver].settings) logger.info(f"registering devices in driver {args.driver!r}") for dev in daemon.devs.values(): @@ -136,7 +137,7 @@ def tomato_driver() -> None: socks = dict(poller.poll(100)) if rep in socks: msg = rep.recv_pyobj() - logger.debug(f"received {msg=}") + logger.debug("received msg=%s", msg) if "cmd" not in msg: logger.error(f"received msg without cmd: {msg=}") ret = Reply(success=False, msg="received msg without cmd", data=msg) @@ -161,26 +162,16 @@ def tomato_driver() -> None: msg="settings received", data=msg.get("params"), ) - elif msg["cmd"] == "dev_register": - interface.dev_register(**msg["params"]) - ret = Reply( - success=True, - msg="device registered", - data=msg.get("params"), - ) - elif msg["cmd"] == "task_status": - ret = interface.task_status(**msg["params"]) - elif msg["cmd"] == "task_start": - ret = interface.task_start(**msg["params"]) - elif msg["cmd"] == "task_data": - ret = interface.task_data(**msg["params"]) + elif hasattr(interface, msg["cmd"]): + ret = getattr(interface, msg["cmd"])(**msg["params"]) + logger.debug("replying Reply(success=%s, msg='%s')", ret.success, ret.msg) rep.send_pyobj(ret) if status == "stop": break - logger.info(f"driver {args.driver!r} is beginning teardown") + logger.info(f"driver {args.driver!r} is beginning reset") - interface.teardown() + interface.reset() logger.critical(f"driver {args.driver!r} is quitting") diff --git a/src/tomato/daemon/job.py b/src/tomato/daemon/job.py index 239244fb..a7c4a016 100644 --- a/src/tomato/daemon/job.py +++ b/src/tomato/daemon/job.py @@ -475,7 +475,11 @@ def job_thread( ret = req.recv_pyobj() if ret.success: data_to_pickle(ret.data, datapath, role=component.role) - logger.debug("all tasks done on component '%s'", component.role) + logger.debug("all tasks done on component '%s', resetting", component.role) + req.send_pyobj(dict(cmd="dev_reset", params={**kwargs})) + ret = req.recv_pyobj() + if not ret.success: + logger.warning("could not reset component '%s': %s", component.role, ret.msg) def job_main_loop( diff --git a/src/tomato/driverinterface_1_0/__init__.py b/src/tomato/driverinterface_1_0/__init__.py index 3bbae81b..e456ff40 100644 --- a/src/tomato/driverinterface_1_0/__init__.py +++ b/src/tomato/driverinterface_1_0/__init__.py @@ -1,5 +1,5 @@ from abc import ABCMeta, abstractmethod -from typing import TypeVar, Any, Union +from typing import TypeVar, Any from pydantic import BaseModel from threading import Thread, current_thread, RLock from queue import Queue @@ -186,6 +186,14 @@ def status(self, **kwargs) -> dict: status[attr] = self.get_attr(attr) return status + def reset(self, **kwargs) -> None: + """Resets the component to an initial status.""" + self.task_list = Queue() + self.thread = Thread(target=self.task_runner, daemon=True) + self.data = defaultdict(list) + self.running = False + self.datalock = RLock() + def CreateDeviceManager(self, key, **kwargs): """ A factory function which is used to pass this :class:`ModelInterface` to the new @@ -203,7 +211,7 @@ def __init__(self, settings=None): self.devmap = {} self.settings = settings if settings is not None else {} - def dev_register(self, address: str, channel: int, **kwargs: dict) -> None: + def dev_register(self, address: str, channel: int, **kwargs: dict) -> Reply: """ Register a new device component in this driver. @@ -212,9 +220,14 @@ def dev_register(self, address: str, channel: int, **kwargs: dict) -> None: """ key = (address, channel) self.devmap[key] = self.CreateDeviceManager(key, **kwargs) + return Reply( + success=True, + msg=f"device {key!r} registered", + data=self.devmap[key], + ) @in_devmap - def dev_teardown(self, key: tuple, **kwargs: dict) -> None: + def dev_teardown(self, key: tuple, **kwargs: dict) -> Reply: """ Emergency stop function. @@ -227,10 +240,30 @@ def dev_teardown(self, key: tuple, **kwargs: dict) -> None: if status.data: logger.warning("tearing down component '%s' with a running task!") self.task_stop(key=key, **kwargs) + self.dev_reset(key=key, **kwargs) del self.devmap[key] + return Reply( + success=True, + msg=f"device {key!r} torn down", + ) @in_devmap - def attrs(self, key: tuple, **kwargs: dict) -> Union[Reply, None]: + def dev_reset(self, key: tuple, **kwargs: dict) -> Reply: + """ + Component reset function. + + Should set the device component into a documented, safe state. This function + is executed at the end of every job. + """ + logger.debug("resetting component '%s'", key) + self.devmap[key].reset() + return Reply( + success=True, + msg=f"component {key!r} reset successfully", + ) + + @in_devmap + def attrs(self, key: tuple, **kwargs: dict) -> Reply: """ Query available :class:`Attrs` on the specified device component. @@ -239,14 +272,12 @@ def attrs(self, key: tuple, **kwargs: dict) -> Union[Reply, None]: ret = self.devmap[key].attrs(**kwargs) return Reply( success=True, - msg=f"attrs of component {key} are: {ret}", + msg=f"attrs of component {key!r} are: {ret}", data=ret, ) @in_devmap - def dev_set_attr( - self, attr: str, val: Any, key: tuple, **kwargs: dict - ) -> Union[Reply, None]: + def dev_set_attr(self, attr: str, val: Any, key: tuple, **kwargs: dict) -> Reply: """ Set value of the :class:`Attr` of the specified device component. @@ -261,7 +292,7 @@ def dev_set_attr( ) @in_devmap - def dev_get_attr(self, attr: str, key: tuple, **kwargs: dict) -> Union[Reply, None]: + def dev_get_attr(self, attr: str, key: tuple, **kwargs: dict) -> Reply: """ Get value of the :class:`Attr` from the specified device component. @@ -277,7 +308,7 @@ def dev_get_attr(self, attr: str, key: tuple, **kwargs: dict) -> Union[Reply, No ) @in_devmap - def dev_status(self, key: tuple, **kwargs: dict) -> Union[Reply, None]: + def dev_status(self, key: tuple, **kwargs: dict) -> Reply: """ Get the status report from the specified device component. @@ -297,7 +328,7 @@ def dev_status(self, key: tuple, **kwargs: dict) -> Union[Reply, None]: ) @in_devmap - def task_start(self, key: tuple, task: Task, **kwargs) -> Union[Reply, None]: + def task_start(self, key: tuple, task: Task, **kwargs) -> Reply: """ Submit a :class:`Task` onto the specified device component. @@ -337,7 +368,7 @@ def task_status(self, key: tuple, **kwargs: dict) -> Reply: return Reply(success=True, msg="ready", data=data) @in_devmap - def task_stop(self, key: tuple, **kwargs) -> Union[Reply, None]: + def task_stop(self, key: tuple, **kwargs) -> Reply: """ Stops a running task and returns any collected data. @@ -354,7 +385,7 @@ def task_stop(self, key: tuple, **kwargs) -> Union[Reply, None]: return Reply(success=True, msg=f"task stopped, {ret.msg}") @in_devmap - def task_data(self, key: tuple, **kwargs) -> Union[Reply, None]: + def task_data(self, key: tuple, **kwargs) -> Reply: """ Return cached task data on the device component and clean the cache. @@ -380,7 +411,7 @@ def task_data(self, key: tuple, **kwargs) -> Union[Reply, None]: ds = Dataset(data_vars=data_vars, coords=uts) return Reply(success=True, msg=f"found {len(data)} new datapoints", data=ds) - def status(self): + def status(self) -> Reply: """ Returns the driver status. Currently that is the names of the components in the `devmap`. @@ -393,24 +424,32 @@ def status(self): data=dict(devkeys=devkeys), ) - def teardown(self): + def reset(self) -> Reply: """ - Tears down the driver. + Resets the driver. Called when the driver process is quitting. Instructs all remaining tasks to - stop. Warns when devices linger. This is not a pass-through to :func:`dev_teardown`. + stop. Warns when devices linger. Passes through to :func:`dev_reset`. This is + not a pass-through to :func:`dev_teardown`. """ for key, dev in self.devmap.items(): - setattr(dev.thread, "do_run", False) - dev.thread.join(1) if dev.thread.is_alive(): - logger.error(f"device {key!r} is still alive") + logger.warning("stopping task on component '%s'", key) + setattr(dev.thread, "do_run", False) + dev.thread.join(timeout=1) + if dev.thread.is_alive(): + logger.error("task on component '%s' is still running", key) else: - logger.debug(f"device {key!r} successfully closed") + logger.debug("component '%s' has no running task", key) + self.dev_reset(key=key) + return Reply( + success=True, + msg="all components on driver have been reset", + ) @in_devmap - def capabilities(self, key: tuple, **kwargs) -> dict: + def capabilities(self, key: tuple, **kwargs) -> Reply: """ Returns the capabilities of the device component. From 58cc38740e9394ccb57db9c4c853ee88ce428d0f Mon Sep 17 00:00:00 2001 From: Peter Kraus Date: Wed, 17 Jul 2024 09:53:55 +0200 Subject: [PATCH 17/17] Fix failing test. --- src/tomato/daemon/driver.py | 6 ++++-- src/tomato/driverinterface_1_0/__init__.py | 2 +- src/tomato/tomato/__init__.py | 19 +++++++++---------- 3 files changed, 14 insertions(+), 13 deletions(-) diff --git a/src/tomato/daemon/driver.py b/src/tomato/daemon/driver.py index fc58fc28..2dd3e90f 100644 --- a/src/tomato/daemon/driver.py +++ b/src/tomato/daemon/driver.py @@ -128,7 +128,7 @@ def tomato_driver() -> None: logger.debug(f"{ret=}") return - logger.info(f"driver {args.driver!r} is entering main loop") + logger.info("driver '%s' is entering main loop", args.driver) poller = zmq.Poller() poller.register(rep, zmq.POLLIN) @@ -164,7 +164,9 @@ def tomato_driver() -> None: ) elif hasattr(interface, msg["cmd"]): ret = getattr(interface, msg["cmd"])(**msg["params"]) - logger.debug("replying Reply(success=%s, msg='%s')", ret.success, ret.msg) + else: + logger.critical("unknown command: '%s'", msg["cmd"]) + logger.debug("replying %s", ret) rep.send_pyobj(ret) if status == "stop": break diff --git a/src/tomato/driverinterface_1_0/__init__.py b/src/tomato/driverinterface_1_0/__init__.py index e456ff40..cffb1a51 100644 --- a/src/tomato/driverinterface_1_0/__init__.py +++ b/src/tomato/driverinterface_1_0/__init__.py @@ -223,7 +223,7 @@ def dev_register(self, address: str, channel: int, **kwargs: dict) -> Reply: return Reply( success=True, msg=f"device {key!r} registered", - data=self.devmap[key], + data=key, ) @in_devmap diff --git a/src/tomato/tomato/__init__.py b/src/tomato/tomato/__init__.py index 39bca3e6..b3aaa768 100644 --- a/src/tomato/tomato/__init__.py +++ b/src/tomato/tomato/__init__.py @@ -39,6 +39,7 @@ logger = logging.getLogger(__name__) VERSION = metadata.version("tomato") +MAX_RETRIES = 10 def set_loglevel(delta: int): @@ -48,7 +49,7 @@ def set_loglevel(delta: int): def load_device_file(yamlpath: Path) -> dict: - logger.debug(f"loading device file from '{yamlpath}'") + logger.debug("loading device file from '%s'", yamlpath) try: with yamlpath.open("r") as infile: jsdata = yaml.safe_load(infile) @@ -127,7 +128,7 @@ def status( If ``with_data`` is specified, the state of the daemon will be retrieved. """ - logger.debug(f"checking status of tomato on port {port}") + logger.debug("checking status of tomato on port %d", port) req = context.socket(zmq.REQ) req.connect(f"tcp://127.0.0.1:{port}") req.send_pyobj(dict(cmd="status", with_data=with_data, sender=f"{__name__}.status")) @@ -163,7 +164,7 @@ def start( """ Start the tomato daemon. """ - logger.debug(f"checking for availability of port {port}.") + logger.debug("checking for availability of port %d", port) try: rep = context.socket(zmq.REP) rep.bind(f"tcp://127.0.0.1:{port}") @@ -188,7 +189,7 @@ def start( msg=f"settings file not found in {appdir}, run 'tomato init' to create one", ) - logger.debug(f"starting tomato on port {port}") + logger.debug("starting tomato on port %d", port) cmd = [ "tomato-daemon", "-p", @@ -284,7 +285,7 @@ def init( """ ) if not appdir.exists(): - logger.debug(f"creating directory '{appdir.resolve()}'") + logger.debug("creating directory '%s'", appdir.resolve()) os.makedirs(appdir) with (appdir / "settings.toml").open("w", encoding="utf-8") as of: of.write(defaults) @@ -327,8 +328,6 @@ def reload( if not stat.success: return stat daemon = stat.data - logger.debug(f"{daemon.status=}") - logger.debug(f"{daemon.pips=}") req = context.socket(zmq.REQ) req.connect(f"tcp://127.0.0.1:{port}") if daemon.status == "bootstrap": @@ -343,17 +342,16 @@ def reload( ) ) rep = req.recv_pyobj() - logger.debug(rep) elif daemon.status == "running": retries = 0 while True: if any([drv.port is None for drv in daemon.drvs.values()]): retries += 1 logger.warning("not all tomato-drivers are online yet, waiting") - logger.debug(f"{retries=}") + logger.debug("retry number %d / %d", retries, MAX_RETRIES) time.sleep(timeout / 1000) daemon = status(**kwargs, with_data=True).data - elif retries == 10: + elif retries == MAX_RETRIES: return Reply( success=False, msg="tomato-drivers are not online", data=daemon ) @@ -389,6 +387,7 @@ def reload( capabilities=dev.capabilities, ) logger.debug(f"{params=}") + logger.debug(f"{daemon.drvs[drv.name]}=") ret = _updater( context, daemon.drvs[drv.name].port, "dev_register", params )