diff --git a/src/iotswarm/db.py b/src/iotswarm/db.py index 7a6b65c..4cf7340 100644 --- a/src/iotswarm/db.py +++ b/src/iotswarm/db.py @@ -7,6 +7,7 @@ from iotswarm.queries import CosmosQuery, CosmosSiteQuery import pandas as pd from pathlib import Path +from math import nan logger = logging.getLogger(__name__) @@ -193,7 +194,7 @@ def query_latest_from_site(self, site_id: str) -> dict: A dict of the data row. """ - data = self.connection.query("SITE_ID == @site_id") + data = self.connection.query("SITE_ID == @site_id").replace({nan: None}) if site_id not in self.cache or self.cache[site_id] >= len(data): self.cache[site_id] = 1 @@ -201,3 +202,26 @@ def query_latest_from_site(self, site_id: str) -> dict: self.cache[site_id] += 1 return data.iloc[self.cache[site_id] - 1].to_dict() + + def query_site_ids(self, max_sites: int | None = None) -> list: + """query_site_ids returns a list of site IDs from the database + + Args: + max_sites: Maximum number of sites to retreive + + Returns: + List[str]: A list of site ID strings. + """ + if max_sites is not None: + max_sites = int(max_sites) + if max_sites < 0: + raise ValueError( + f"`max_sites` must be 1 or more, or 0 for no maximum. Received: {max_sites}" + ) + + sites = self.connection["SITE_ID"].drop_duplicates().to_list() + + if max_sites is not None and max_sites > 0: + sites = sites[:max_sites] + + return sites diff --git a/src/iotswarm/devices.py b/src/iotswarm/devices.py index 8ae240e..c128e0a 100644 --- a/src/iotswarm/devices.py +++ b/src/iotswarm/devices.py @@ -275,12 +275,10 @@ async def _get_payload(self): return await self.data_source.query_latest_from_site( self.device_id, self.query ) - - elif isinstance(self.data_source, BaseDatabase): - return self.data_source.query_latest_from_site() - elif isinstance(self.data_source, LoopingCsvDB): return self.data_source.query_latest_from_site(self.device_id) + elif isinstance(self.data_source, BaseDatabase): + return self.data_source.query_latest_from_site() def _format_payload(self, payload): """Oranises payload into correct structure.""" diff --git a/src/iotswarm/scripts/cli.py b/src/iotswarm/scripts/cli.py index 2517545..7ffe8ac 100644 --- a/src/iotswarm/scripts/cli.py +++ b/src/iotswarm/scripts/cli.py @@ -4,7 +4,7 @@ from iotswarm import queries from iotswarm.devices import BaseDevice, CR1000XDevice from iotswarm.swarm import Swarm -from iotswarm.db import Oracle +from iotswarm.db import Oracle, LoopingCsvDB from iotswarm.messaging.core import MockMessageConnection from iotswarm.messaging.aws import IotCoreMQTTConnection import asyncio @@ -161,88 +161,110 @@ async def _list_sites(ctx, query): click.echo(asyncio.run(_list_sites(ctx, query))) +def common_device_options(function): + click.option( + "--sleep-time", + type=click.INT, + help="The number of seconds each site goes idle after sending a message.", + )(function) + + click.option( + "--max-cycles", + type=click.IntRange(0), + help="Maximum number message sending cycles. Runs forever if set to 0.", + )(function) + + click.option( + "--max-sites", + type=click.IntRange(0), + help="Maximum number of sites allowed to initialize. No limit if set to 0.", + )(function) + + click.option( + "--swarm-name", + type=click.STRING, + help="Name given to swarm. Appears in the logs.", + )(function) + + click.option( + "--delay-start", + is_flag=True, + default=False, + help="Adds a random delay before the first message from each site up to `--sleep-time`.", + )(function) + + click.option( + "--device-type", type=click.Choice(["basic", "cr1000x"]), default="basic" + )(function) + + return function + + +def common_iotcore_options(function): + click.argument( + "client-id", + type=click.STRING, + required=True, + )(function) + + click.option( + "--endpoint", + type=click.STRING, + required=True, + envvar="IOT_SWARM_MQTT_ENDPOINT", + help="Endpoint of the MQTT receiving host.", + )(function) + + click.option( + "--cert-path", + type=click.Path(exists=True), + required=True, + envvar="IOT_SWARM_MQTT_CERT_PATH", + help="Path to public key certificate for the device. Must match key assigned to the `--client-id` in the cloud provider.", + )(function) + + click.option( + "--key-path", + type=click.Path(exists=True), + required=True, + envvar="IOT_SWARM_MQTT_KEY_PATH", + help="Path to the private key that pairs with the `--cert-path`.", + )(function) + + click.option( + "--ca-cert-path", + type=click.Path(exists=True), + required=True, + envvar="IOT_SWARM_MQTT_CA_CERT_PATH", + help="Path to the root Certificate Authority (CA) for the MQTT host.", + )(function) + + click.option( + "--mqtt-prefix", + type=click.STRING, + help="Prefixes the MQTT topic with a string. Can augment the calculated MQTT topic returned by each site.", + )(function) + + click.option( + "--mqtt-suffix", + type=click.STRING, + help="Suffixes the MQTT topic with a string. Can augment the calculated MQTT topic returned by each site.", + )(function) + + return function + + @cosmos.command() @click.pass_context -@click.argument( - "provider", - type=click.Choice(["aws"]), -) @click.argument( "query", type=click.Choice(TABLES), ) -@click.argument( - "client-id", - type=click.STRING, - required=True, -) -@click.option( - "--endpoint", - type=click.STRING, - required=True, - envvar="IOT_SWARM_MQTT_ENDPOINT", - help="Endpoint of the MQTT receiving host.", -) -@click.option( - "--cert-path", - type=click.Path(exists=True), - required=True, - envvar="IOT_SWARM_MQTT_CERT_PATH", - help="Path to public key certificate for the device. Must match key assigned to the `--client-id` in the cloud provider.", -) -@click.option( - "--key-path", - type=click.Path(exists=True), - required=True, - envvar="IOT_SWARM_MQTT_KEY_PATH", - help="Path to the private key that pairs with the `--cert-path`.", -) -@click.option( - "--ca-cert-path", - type=click.Path(exists=True), - required=True, - envvar="IOT_SWARM_MQTT_CA_CERT_PATH", - help="Path to the root Certificate Authority (CA) for the MQTT host.", -) -@click.option( - "--sleep-time", - type=click.INT, - help="The number of seconds each site goes idle after sending a message.", -) -@click.option( - "--max-cycles", - type=click.IntRange(0), - help="Maximum number message sending cycles. Runs forever if set to 0.", -) -@click.option( - "--max-sites", - type=click.IntRange(0), - help="Maximum number of sites allowed to initialize. No limit if set to 0.", -) -@click.option( - "--swarm-name", type=click.STRING, help="Name given to swarm. Appears in the logs." -) -@click.option( - "--delay-start", - is_flag=True, - default=False, - help="Adds a random delay before the first message from each site up to `--sleep-time`.", -) -@click.option( - "--mqtt-prefix", - type=click.STRING, - help="Prefixes the MQTT topic with a string. Can augment the calculated MQTT topic returned by each site.", -) -@click.option( - "--mqtt-suffix", - type=click.STRING, - help="Suffixes the MQTT topic with a string. Can augment the calculated MQTT topic returned by each site.", -) +@common_device_options +@common_iotcore_options @click.option("--dry", is_flag=True, default=False, help="Doesn't send out any data.") -@click.option("--device-type", type=click.Choice(["basic", "cr1000x"]), default="basic") def mqtt( ctx, - provider, query, endpoint, cert_path, @@ -259,7 +281,7 @@ def mqtt( dry, device_type, ): - """Sends The cosmos data via MQTT protocol using PROVIDER. + """Sends The cosmos data via MQTT protocol using IoT Core. Data is collected from the db using QUERY and sent using CLIENT_ID. Currently only supports sending through AWS IoT Core.""" @@ -280,7 +302,7 @@ async def _mqtt(): if dry == True: connection = MockMessageConnection() - elif provider == "aws": + else: connection = IotCoreMQTTConnection( endpoint=endpoint, cert_path=cert_path, @@ -316,5 +338,113 @@ async def _mqtt(): asyncio.run(_mqtt()) +@main.group() +@click.pass_context +@click.option( + "--site", + type=click.STRING, + multiple=True, + help="Adds a site to be initialized. Can be invoked multiple times for other sites." + " Grabs all sites from database query if none provided", +) +@click.option( + "--file", + type=click.Path(exists=True), + required=True, + envvar="IOT_SWARM_CSV_DB", + help="*.csv file used to instantiate a pandas database.", +) +def looping_csv(ctx, site, file): + """Instantiates a pandas dataframe from a csv file which is used as the database. + Responsibility falls on the user to ensure the correct file is selected.""" + + ctx.obj["db"] = LoopingCsvDB(file) + ctx.obj["sites"] = site + + +looping_csv.add_command(test) + + +@looping_csv.command +@click.pass_context +@click.option("--max-sites", type=click.IntRange(min=0), default=0) +def list_sites(ctx, max_sites): + """Prints the sites present in database.""" + + sites = ctx.obj["db"].query_site_ids(max_sites=max_sites) + click.echo(sites) + + +@looping_csv.command() +@click.pass_context +@common_device_options +@common_iotcore_options +@click.option("--dry", is_flag=True, default=False, help="Doesn't send out any data.") +def mqtt( + ctx, + endpoint, + cert_path, + key_path, + ca_cert_path, + client_id, + sleep_time, + max_cycles, + max_sites, + swarm_name, + delay_start, + mqtt_prefix, + mqtt_suffix, + dry, + device_type, +): + """Sends The cosmos data via MQTT protocol using IoT Core. + Data is collected from the db using QUERY and sent using CLIENT_ID. + + Currently only supports sending through AWS IoT Core.""" + + async def _mqtt(): + + sites = ctx.obj["sites"] + db = ctx.obj["db"] + if len(sites) == 0: + sites = db.query_site_ids(max_sites=max_sites) + + if dry == True: + connection = MockMessageConnection() + else: + connection = IotCoreMQTTConnection( + endpoint=endpoint, + cert_path=cert_path, + key_path=key_path, + ca_cert_path=ca_cert_path, + client_id=client_id, + ) + + if device_type == "basic": + DeviceClass = BaseDevice + elif device_type == "cr1000x": + DeviceClass = CR1000XDevice + + site_devices = [ + DeviceClass( + site, + db, + connection, + sleep_time=sleep_time, + max_cycles=max_cycles, + delay_start=delay_start, + mqtt_prefix=mqtt_prefix, + mqtt_suffix=mqtt_suffix, + ) + for site in sites + ] + + swarm = Swarm(site_devices, swarm_name) + + await swarm.run() + + asyncio.run(_mqtt()) + + if __name__ == "__main__": main(auto_envvar_prefix="IOT_SWARM", obj={}) diff --git a/src/tests/test_db.py b/src/tests/test_db.py index b2d24c5..d24a0bf 100644 --- a/src/tests/test_db.py +++ b/src/tests/test_db.py @@ -308,6 +308,30 @@ def test_cache_counter_restarts_at_end(self): self.assertEqual(len(expected), len(data)) + @data_files_exist + def test_site_ids_can_be_retrieved(self): + database = db.LoopingCsvDB(self.data_path["LEVEL1_SOILMET_30MIN"]) + + site_ids_full = database.query_site_ids() + site_ids_exp_full = database.query_site_ids(max_sites=0) + + + self.assertIsInstance(site_ids_full, list) + + self.assertGreater(len(site_ids_full), 0) + for site in site_ids_full: + self.assertIsInstance(site, str) + + self.assertEqual(len(site_ids_full), len(site_ids_exp_full)) + + site_ids_limit = database.query_site_ids(max_sites=5) + + self.assertEqual(len(site_ids_limit), 5) + self.assertGreater(len(site_ids_full), len(site_ids_limit)) + + with self.assertRaises(ValueError): + + database.query_site_ids(max_sites=-1) class TestLoopingCsvDBEndToEnd(unittest.IsolatedAsyncioTestCase): """Tests the LoopingCsvDB class."""