Skip to content

Commit f8373fd

Browse files
Session manager Implementation (#7)
This pull request enables swarms to persist their state between runs and pick up where they left off if a stop is required. * NEW: Swarms can now write their state to file as a ".pkl" file * NEW: CLI updated to implement new behaviour. Swarm sessions can be resumed by providing the --resume-session and --swarm-name arguments. * NEW: Sessions are stored in '/swarms/` * NEW: Implemented CLI commands for managing swarms iot-swarm sessions has commands to initialise, list, or remove sessions * NEW: Tests updated for new behaviour. * Swarm.load_swarm(<swarm-id>) can also be used to load a swarm from file
1 parent ebb5185 commit f8373fd

File tree

12 files changed

+735
-174
lines changed

12 files changed

+735
-174
lines changed

pyproject.toml

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,19 @@ requires = ["setuptools >= 61.0", "autosemver"]
44

55
[project]
66
dependencies = [
7-
"platformdirs",
8-
"boto3",
97
"autosemver",
10-
"config",
11-
"click",
12-
"docutils<0.17",
138
"awscli",
149
"awscrt",
1510
"awsiotsdk",
16-
"oracledb",
1711
"backoff",
12+
"boto3",
13+
"click",
14+
"config",
15+
"dill",
16+
"docutils<0.17",
17+
"oracledb",
1818
"pandas",
19+
"platformdirs",
1920
]
2021
name = "iot-swarm"
2122
dynamic = ["version"]

src/iotswarm/db.py

Lines changed: 63 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from pathlib import Path
1313
from math import nan
1414
import sqlite3
15+
from typing import List
1516

1617
logger = logging.getLogger(__name__)
1718

@@ -39,15 +40,18 @@ def __repr__(self):
3940
logger_arg = f"inherit_logger={self._instance_logger.parent}"
4041
return f"{self.__class__.__name__}({logger_arg})"
4142

43+
def __eq__(self, obj):
44+
return self._instance_logger == obj._instance_logger
45+
4246
@abc.abstractmethod
43-
def query_latest_from_site(self):
47+
def query_latest_from_site(self) -> List:
4448
pass
4549

4650

4751
class MockDB(BaseDatabase):
4852

4953
@staticmethod
50-
def query_latest_from_site():
54+
def query_latest_from_site() -> List:
5155
return []
5256

5357

@@ -63,6 +67,14 @@ class CosmosDB(BaseDatabase):
6367
site_id_query: CosmosQuery
6468
"""SQL query for retrieving list of site IDs"""
6569

70+
def __eq__(self, obj):
71+
return (
72+
type(self.connection) == type(obj.connection)
73+
and self.site_data_query == obj.site_data_query
74+
and self.site_id_query == obj.site_id_query
75+
and BaseDatabase.__eq__(self, obj)
76+
)
77+
6678
@staticmethod
6779
def _validate_table(table: CosmosTable) -> None:
6880
"""Validates that the query is legal"""
@@ -99,6 +111,9 @@ def _validate_max_sites(max_sites: int) -> int:
99111

100112
return max_sites
101113

114+
def query_latest_from_site(self):
115+
pass
116+
102117

103118
class Oracle(CosmosDB):
104119
"""Class for handling oracledb logic and retrieving values from DB."""
@@ -225,8 +240,16 @@ class LoopingCsvDB(BaseDatabase):
225240
connection: pd.DataFrame
226241
"""Connection to the pd object holding data."""
227242

228-
cache: dict
229-
"""Cache object containing current index of each site queried."""
243+
db_file: str | Path
244+
"""Path to the database file."""
245+
246+
def __eq__(self, obj):
247+
248+
return (
249+
type(self.connection) == type(obj.connection)
250+
and self.db_file == obj.db_file
251+
and BaseDatabase.__eq__(self, obj)
252+
)
230253

231254
@staticmethod
232255
def _get_connection(*args) -> pd.DataFrame:
@@ -241,27 +264,30 @@ def __init__(self, csv_file: str | Path):
241264
"""
242265

243266
BaseDatabase.__init__(self)
267+
268+
if not isinstance(csv_file, Path):
269+
csv_file = Path(csv_file)
270+
271+
self.db_file = csv_file
244272
self.connection = self._get_connection(csv_file)
245-
self.cache = dict()
246273

247-
def query_latest_from_site(self, site_id: str) -> dict:
274+
def query_latest_from_site(self, site_id: str, index: int) -> dict:
248275
"""Queries the datbase for a `SITE_ID` incrementing by 1 each time called
249276
for a specific site. If the end is reached, it loops back to the start.
250277
251278
Args:
252279
site_id: ID of the site to query for.
280+
index: An offset index to query.
253281
Returns:
254282
A dict of the data row.
255283
"""
256284

257285
data = self.connection.query("SITE_ID == @site_id").replace({nan: None})
258286

259-
if site_id not in self.cache or self.cache[site_id] >= len(data):
260-
self.cache[site_id] = 1
261-
else:
262-
self.cache[site_id] += 1
287+
# Automatically loops back to start
288+
db_index = index % len(data)
263289

264-
return data.iloc[self.cache[site_id] - 1].to_dict()
290+
return data.iloc[db_index].to_dict()
265291

266292
def query_site_ids(self, max_sites: int | None = None) -> list:
267293
"""query_site_ids returns a list of site IDs from the database
@@ -316,32 +342,49 @@ def __init__(self, db_file: str | Path):
316342

317343
self.cursor = self.connection.cursor()
318344

319-
def query_latest_from_site(self, site_id: str, table: CosmosTable) -> dict:
345+
def __eq__(self, obj) -> bool:
346+
return CosmosDB.__eq__(self, obj) and super(LoopingCsvDB, self).__eq__(obj)
347+
348+
def __getstate__(self) -> object:
349+
350+
state = self.__dict__.copy()
351+
352+
del state["connection"]
353+
del state["cursor"]
354+
355+
return state
356+
357+
def __setstate__(self, state) -> object:
358+
359+
self.__dict__.update(state)
360+
361+
self.connection = self._get_connection(self.db_file)
362+
self.cursor = self.connection.cursor()
363+
364+
def query_latest_from_site(
365+
self, site_id: str, table: CosmosTable, index: int
366+
) -> dict:
320367
"""Queries the datbase for a `SITE_ID` incrementing by 1 each time called
321368
for a specific site. If the end is reached, it loops back to the start.
322369
323370
Args:
324371
site_id: ID of the site to query for.
325372
table: A valid table from the database
373+
index: Offset of index.
326374
Returns:
327375
A dict of the data row.
328376
"""
329377
query = self._fill_query(self.site_data_query, table)
330378

331-
if site_id not in self.cache:
332-
self.cache[site_id] = 0
333-
else:
334-
self.cache[site_id] += 1
335-
336379
data = self._query_latest_from_site(
337-
query, {"site_id": site_id, "offset": self.cache[site_id]}
380+
query, {"site_id": site_id, "offset": index}
338381
)
339382

340383
if data is None:
341-
self.cache[site_id] = 0
384+
index = 0
342385

343386
data = self._query_latest_from_site(
344-
query, {"site_id": site_id, "offset": self.cache[site_id]}
387+
query, {"site_id": site_id, "offset": index}
345388
)
346389

347390
return data

src/iotswarm/devices.py

Lines changed: 60 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@ class BaseDevice:
5757
mqtt_suffix: str
5858
"""Suffix added to mqtt message."""
5959

60+
swarm: object | None = None
61+
"""The session applied"""
62+
6063
@property
6164
def mqtt_topic(self) -> str:
6265
"Builds the mqtt topic."
@@ -75,6 +78,30 @@ def mqtt_topic(self, value):
7578
self._mqtt_topic = value
7679
self.mqtt_base_topic = value
7780

81+
def __eq__(self, obj) -> bool:
82+
83+
base_equality = (
84+
self.device_type == obj.device_type
85+
and self.cycle == obj.cycle
86+
and self.max_cycles == obj.max_cycles
87+
and self.sleep_time == obj.sleep_time
88+
and self.device_id == obj.device_id
89+
and self.delay_start == obj.delay_start
90+
and self._instance_logger == obj._instance_logger
91+
and self.data_source == obj.data_source
92+
and self.connection == obj.connection
93+
)
94+
95+
table_equality = True
96+
if hasattr(self, "table") and not self.table == obj.table:
97+
table_equality = False
98+
99+
mqtt_equality = True
100+
if hasattr(self, "mqtt_topic") and not self.mqtt_topic == obj.mqtt_topic:
101+
mqtt_equality = False
102+
103+
return base_equality and table_equality and mqtt_equality
104+
78105
def __init__(
79106
self,
80107
device_id: str,
@@ -235,12 +262,19 @@ async def _add_delay(self):
235262
self._instance_logger.debug(f"Delaying first cycle for: {delay}s.")
236263
await asyncio.sleep(delay)
237264

238-
def _send_payload(self, payload: dict):
265+
def _send_payload(self, payload: dict) -> bool:
266+
"""Forwards the payload submission request to the connection
267+
268+
Args:
269+
payload: The data to send.
270+
Returns:
271+
bool: True if sent sucessfully, else false.
272+
"""
239273

240274
if isinstance(self.connection, IotCoreMQTTConnection):
241-
self.connection.send_message(payload, topic=self.mqtt_topic)
275+
return self.connection.send_message(payload, topic=self.mqtt_topic)
242276
else:
243-
self.connection.send_message(payload)
277+
return self.connection.send_message(payload)
244278

245279
async def run(self):
246280
"""The main invocation of the method. Expects a Oracle object to do work on
@@ -251,26 +285,32 @@ async def run(self):
251285
"""
252286

253287
while True:
288+
if self.max_cycles > 0 and self.cycle >= self.max_cycles:
289+
break
254290

255291
if self.delay_start and self.cycle == 0:
256292
await self._add_delay()
257293

258294
payload = await self._get_payload()
259-
payload = self._format_payload(payload)
260295

261-
if payload:
296+
if payload is not None:
297+
payload = self._format_payload(payload)
298+
262299
self._instance_logger.debug("Requesting payload submission.")
263-
self._send_payload(payload)
264-
self._instance_logger.info(
265-
f'Message sent{f" to topic: {self.mqtt_topic}" if self.mqtt_topic else ""}'
266-
)
300+
301+
send_status = self._send_payload(payload)
302+
303+
if send_status == True:
304+
self._instance_logger.info(
305+
f'Message sent{f" to topic: {self.mqtt_topic}" if self.mqtt_topic else ""}'
306+
)
307+
self.cycle += 1
308+
309+
if self.swarm is not None:
310+
self.swarm.write_self(replace=True)
267311
else:
268312
self._instance_logger.warning(f"No data found.")
269313

270-
self.cycle += 1
271-
if self.max_cycles > 0 and self.cycle >= self.max_cycles:
272-
break
273-
274314
await asyncio.sleep(self.sleep_time)
275315

276316
async def _get_payload(self):
@@ -280,16 +320,21 @@ async def _get_payload(self):
280320
self.device_id, self.table
281321
)
282322
elif isinstance(self.data_source, LoopingSQLite3):
283-
return self.data_source.query_latest_from_site(self.device_id, self.table)
323+
return self.data_source.query_latest_from_site(
324+
self.device_id, self.table, self.cycle
325+
)
284326
elif isinstance(self.data_source, LoopingCsvDB):
285-
return self.data_source.query_latest_from_site(self.device_id)
327+
return self.data_source.query_latest_from_site(self.device_id, self.cycle)
286328
elif isinstance(self.data_source, BaseDatabase):
287329
return self.data_source.query_latest_from_site()
288330

289331
def _format_payload(self, payload):
290332
"""Oranises payload into correct structure."""
291333
return payload
292334

335+
def _attach_swarm(self, swarm: object):
336+
self.swarm = swarm
337+
293338

294339
class CR1000XDevice(BaseDevice):
295340
"Represents a CR1000X datalogger."

0 commit comments

Comments
 (0)