diff --git a/pyproject.toml b/pyproject.toml index f5e86c6..a57aaa2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,4 +48,4 @@ markers = [ ] [tool.coverage.run] -omit = ["*example.py", "*__init__.py", "queries.py", "loggers.py"] +omit = ["*example.py", "*__init__.py", "queries.py", "loggers.py", "cli.py"] diff --git a/src/iotdevicesimulator/db.py b/src/iotdevicesimulator/db.py index c2bb1e6..e0512c2 100644 --- a/src/iotdevicesimulator/db.py +++ b/src/iotdevicesimulator/db.py @@ -147,7 +147,10 @@ async def query_site_ids( await cursor.execute(query.value) data = await cursor.fetchall() - data = [x[0] for x in data[:max_sites]] + if max_sites == 0: + data = [x[0] for x in data] + else: + data = [x[0] for x in data[:max_sites]] if not data: data = [] diff --git a/src/iotdevicesimulator/example.py b/src/iotdevicesimulator/example.py index 41ba775..663b44f 100644 --- a/src/iotdevicesimulator/example.py +++ b/src/iotdevicesimulator/example.py @@ -44,23 +44,15 @@ async def main(config_path: str): client_id="fdri_swarm", ) - device = devices.DeviceFactory.create_device( - MockMessageConnection(), - data_source, - query, - "site_id", - device_type="cr1000x", - ) - - # devices = [ - # MQTTCosmosDevice( - # query, site, data_source, mqtt_connection, topic_prefix="fdri/cosmos_sites" - # ) - # for site in device_ids - # ] - - # swarm = Swarm(devices, name="soilmet") - # await swarm.run() + device_objs = [ + devices.CR1000XDevice( + site, data_source, mqtt_connection, query=query, sleep_time=5 + ) + for site in device_ids + ] + + swarm = Swarm(device_objs, name="soilmet") + await swarm.run() if __name__ == "__main__": diff --git a/src/iotdevicesimulator/scripts/cli.py b/src/iotdevicesimulator/scripts/cli.py index aae57e3..fd01e94 100644 --- a/src/iotdevicesimulator/scripts/cli.py +++ b/src/iotdevicesimulator/scripts/cli.py @@ -2,7 +2,9 @@ import click from iotdevicesimulator import queries -from iotdevicesimulator.swarm import CosmosSwarm +from iotdevicesimulator.devices import BaseDevice, CR1000XDevice +from iotdevicesimulator.swarm import Swarm +from iotdevicesimulator.db import Oracle, MockDB from iotdevicesimulator.messaging.core import MockMessageConnection from iotdevicesimulator.messaging.aws import IotCoreMQTTConnection import asyncio @@ -12,6 +14,13 @@ TABLES = [table.name for table in queries.CosmosQuery] +@click.command +@click.pass_context +def test(ctx: click.Context): + """Enables testing of cosmos group arguments.""" + print(ctx.obj) + + @click.group() @click.pass_context @click.option( @@ -29,6 +38,9 @@ def main(ctx: click.Context, log_config: Path): logging.config.fileConfig(fname=log_config) +main.add_command(test) + + @main.group() @click.pass_context @click.option( @@ -66,22 +78,24 @@ def cosmos(ctx: click.Context, site: str, dsn: str, user: str, password: str): ctx.obj["sites"] = site -@cosmos.command() -def test(): - """Enables testing of cosmos group arguments.""" - pass +cosmos.add_command(test) @cosmos.command() @click.pass_context @click.argument("query", type=click.Choice(TABLES)) -def list_sites(ctx, query): +@click.option("--max-sites", type=click.IntRange(min=0), default=0) +def list_sites(ctx, query, max_sites): """Lists site IDs from the database from table QUERY.""" async def _list_sites(ctx, query): - oracle = await CosmosSwarm._get_oracle(ctx.obj["credentials"]) - sites = await CosmosSwarm._get_sites_from_db( - oracle, queries.CosmosSiteQuery[query] + oracle = await Oracle.create( + dsn=ctx.obj["credentials"]["dsn"], + user=ctx.obj["credentials"]["user"], + password=ctx.obj["credentials"]["password"], + ) + sites = await oracle.query_site_ids( + queries.CosmosSiteQuery[query], max_sites=max_sites ) return sites @@ -156,17 +170,17 @@ async def _list_sites(ctx, query): help="Adds a random delay before the first message from each site up to `--sleep-time`.", ) @click.option( - "--topic-prefix", + "--mqtt-prefix", type=click.STRING, - default="fdri/cosmos_test", help="Prefixes the MQTT topic with a string. Can augment the calculated MQTT topic returned by each site.", ) @click.option( - "--topic-suffix", + "--mqtt-suffix", type=click.STRING, help="Suffixes the MQTT topic with a string. Can augment the calculated MQTT topic returned by each site.", ) @click.option("--dry", is_flag=True, default=False, help="Doesn't send out any data.") +@click.option("--device-type", type=click.Choice(["basic", "cr1000x"]), default="basic") def mqtt( ctx, provider, @@ -181,49 +195,66 @@ def mqtt( max_sites, swarm_name, delay_start, - topic_prefix, - topic_suffix, + mqtt_prefix, + mqtt_suffix, dry, + device_type, ): """Sends The cosmos data via MQTT protocol using PROVIDER. Data is collected from the db using QUERY and sent using CLIENT_ID. Currently only supports sending through AWS IoT Core.""" - query = queries.CosmosQuery[query] - async def _swarm(query, mqtt_connection, credentials, *args, **kwargs): - swarm = await CosmosSwarm.create( - query, mqtt_connection, credentials, *args, **kwargs + async def _mqtt(): + oracle = await Oracle.create( + dsn=ctx.obj["credentials"]["dsn"], + user=ctx.obj["credentials"]["user"], + password=ctx.obj["credentials"]["password"], ) - await swarm.run() + data_query = queries.CosmosQuery[query] + site_query = queries.CosmosSiteQuery[query] - if dry == True: - connection = MockMessageConnection() - elif provider == "aws": - connection = IotCoreMQTTConnection( - endpoint=endpoint, - cert_path=cert_path, - key_path=key_path, - ca_cert_path=ca_cert_path, - client_id=client_id, - ) + sites = ctx.obj["sites"] + if len(sites) == 0: + sites = await oracle.query_site_ids(site_query, max_sites=max_sites) - asyncio.run( - _swarm( - query, - connection, - ctx.obj["credentials"], - site_ids=ctx.obj["sites"], - sleep_time=sleep_time, - max_cycles=max_cycles, - max_sites=max_sites, - swarm_name=swarm_name, - delay_start=delay_start, - topic_prefix=topic_prefix, - topic_suffix=topic_suffix, - ) - ) + if dry == True: + connection = MockMessageConnection() + elif provider == "aws": + connection = IotCoreMQTTConnection( + endpoint=endpoint, + cert_path=cert_path, + key_path=key_path, + ca_cert_path=ca_cert_path, + client_id=client_id, + ) + + if device_type == "basic": + DeviceClass = BaseDevice + elif device_type == "cr1000x": + DeviceClass = CR1000XDevice + + site_devices = [ + DeviceClass( + site, + oracle, + connection, + sleep_time=sleep_time, + query=data_query, + max_cycles=max_cycles, + delay_start=delay_start, + mqtt_prefix=mqtt_prefix, + mqtt_suffix=mqtt_suffix, + ) + for site in sites + ] + + swarm = Swarm(site_devices, swarm_name) + + await swarm.run() + + asyncio.run(_mqtt()) if __name__ == "__main__":