12
12
from pathlib import Path
13
13
from math import nan
14
14
import sqlite3
15
+ from typing import List
15
16
16
17
logger = logging .getLogger (__name__ )
17
18
@@ -39,15 +40,18 @@ def __repr__(self):
39
40
logger_arg = f"inherit_logger={ self ._instance_logger .parent } "
40
41
return f"{ self .__class__ .__name__ } ({ logger_arg } )"
41
42
43
+ def __eq__ (self , obj ):
44
+ return self ._instance_logger == obj ._instance_logger
45
+
42
46
@abc .abstractmethod
43
- def query_latest_from_site (self ):
47
+ def query_latest_from_site (self ) -> List :
44
48
pass
45
49
46
50
47
51
class MockDB (BaseDatabase ):
48
52
49
53
@staticmethod
50
- def query_latest_from_site ():
54
+ def query_latest_from_site () -> List :
51
55
return []
52
56
53
57
@@ -63,6 +67,14 @@ class CosmosDB(BaseDatabase):
63
67
site_id_query : CosmosQuery
64
68
"""SQL query for retrieving list of site IDs"""
65
69
70
+ def __eq__ (self , obj ):
71
+ return (
72
+ type (self .connection ) == type (obj .connection )
73
+ and self .site_data_query == obj .site_data_query
74
+ and self .site_id_query == obj .site_id_query
75
+ and BaseDatabase .__eq__ (self , obj )
76
+ )
77
+
66
78
@staticmethod
67
79
def _validate_table (table : CosmosTable ) -> None :
68
80
"""Validates that the query is legal"""
@@ -99,6 +111,9 @@ def _validate_max_sites(max_sites: int) -> int:
99
111
100
112
return max_sites
101
113
114
+ def query_latest_from_site (self ):
115
+ pass
116
+
102
117
103
118
class Oracle (CosmosDB ):
104
119
"""Class for handling oracledb logic and retrieving values from DB."""
@@ -225,8 +240,16 @@ class LoopingCsvDB(BaseDatabase):
225
240
connection : pd .DataFrame
226
241
"""Connection to the pd object holding data."""
227
242
228
- cache : dict
229
- """Cache object containing current index of each site queried."""
243
+ db_file : str | Path
244
+ """Path to the database file."""
245
+
246
+ def __eq__ (self , obj ):
247
+
248
+ return (
249
+ type (self .connection ) == type (obj .connection )
250
+ and self .db_file == obj .db_file
251
+ and BaseDatabase .__eq__ (self , obj )
252
+ )
230
253
231
254
@staticmethod
232
255
def _get_connection (* args ) -> pd .DataFrame :
@@ -241,27 +264,30 @@ def __init__(self, csv_file: str | Path):
241
264
"""
242
265
243
266
BaseDatabase .__init__ (self )
267
+
268
+ if not isinstance (csv_file , Path ):
269
+ csv_file = Path (csv_file )
270
+
271
+ self .db_file = csv_file
244
272
self .connection = self ._get_connection (csv_file )
245
- self .cache = dict ()
246
273
247
- def query_latest_from_site (self , site_id : str ) -> dict :
274
+ def query_latest_from_site (self , site_id : str , index : int ) -> dict :
248
275
"""Queries the datbase for a `SITE_ID` incrementing by 1 each time called
249
276
for a specific site. If the end is reached, it loops back to the start.
250
277
251
278
Args:
252
279
site_id: ID of the site to query for.
280
+ index: An offset index to query.
253
281
Returns:
254
282
A dict of the data row.
255
283
"""
256
284
257
285
data = self .connection .query ("SITE_ID == @site_id" ).replace ({nan : None })
258
286
259
- if site_id not in self .cache or self .cache [site_id ] >= len (data ):
260
- self .cache [site_id ] = 1
261
- else :
262
- self .cache [site_id ] += 1
287
+ # Automatically loops back to start
288
+ db_index = index % len (data )
263
289
264
- return data .iloc [self . cache [ site_id ] - 1 ].to_dict ()
290
+ return data .iloc [db_index ].to_dict ()
265
291
266
292
def query_site_ids (self , max_sites : int | None = None ) -> list :
267
293
"""query_site_ids returns a list of site IDs from the database
@@ -316,32 +342,49 @@ def __init__(self, db_file: str | Path):
316
342
317
343
self .cursor = self .connection .cursor ()
318
344
319
- def query_latest_from_site (self , site_id : str , table : CosmosTable ) -> dict :
345
+ def __eq__ (self , obj ) -> bool :
346
+ return CosmosDB .__eq__ (self , obj ) and super (LoopingCsvDB , self ).__eq__ (obj )
347
+
348
+ def __getstate__ (self ) -> object :
349
+
350
+ state = self .__dict__ .copy ()
351
+
352
+ del state ["connection" ]
353
+ del state ["cursor" ]
354
+
355
+ return state
356
+
357
+ def __setstate__ (self , state ) -> object :
358
+
359
+ self .__dict__ .update (state )
360
+
361
+ self .connection = self ._get_connection (self .db_file )
362
+ self .cursor = self .connection .cursor ()
363
+
364
+ def query_latest_from_site (
365
+ self , site_id : str , table : CosmosTable , index : int
366
+ ) -> dict :
320
367
"""Queries the datbase for a `SITE_ID` incrementing by 1 each time called
321
368
for a specific site. If the end is reached, it loops back to the start.
322
369
323
370
Args:
324
371
site_id: ID of the site to query for.
325
372
table: A valid table from the database
373
+ index: Offset of index.
326
374
Returns:
327
375
A dict of the data row.
328
376
"""
329
377
query = self ._fill_query (self .site_data_query , table )
330
378
331
- if site_id not in self .cache :
332
- self .cache [site_id ] = 0
333
- else :
334
- self .cache [site_id ] += 1
335
-
336
379
data = self ._query_latest_from_site (
337
- query , {"site_id" : site_id , "offset" : self . cache [ site_id ] }
380
+ query , {"site_id" : site_id , "offset" : index }
338
381
)
339
382
340
383
if data is None :
341
- self . cache [ site_id ] = 0
384
+ index = 0
342
385
343
386
data = self ._query_latest_from_site (
344
- query , {"site_id" : site_id , "offset" : self . cache [ site_id ] }
387
+ query , {"site_id" : site_id , "offset" : index }
345
388
)
346
389
347
390
return data
0 commit comments