Skip to content

Commit

Permalink
Updated CLI to allow selection of device type
Browse files Browse the repository at this point in the history
  • Loading branch information
lewis-chambers committed Jun 10, 2024
1 parent 8a9e56f commit a0e249c
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 63 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,4 @@ markers = [
]

[tool.coverage.run]
omit = ["*example.py", "*__init__.py", "queries.py", "loggers.py"]
omit = ["*example.py", "*__init__.py", "queries.py", "loggers.py", "cli.py"]
5 changes: 4 additions & 1 deletion src/iotdevicesimulator/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,10 @@ async def query_site_ids(
await cursor.execute(query.value)

data = await cursor.fetchall()
data = [x[0] for x in data[:max_sites]]
if max_sites == 0:
data = [x[0] for x in data]
else:
data = [x[0] for x in data[:max_sites]]

if not data:
data = []
Expand Down
26 changes: 9 additions & 17 deletions src/iotdevicesimulator/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,23 +44,15 @@ async def main(config_path: str):
client_id="fdri_swarm",
)

device = devices.DeviceFactory.create_device(
MockMessageConnection(),
data_source,
query,
"site_id",
device_type="cr1000x",
)

# devices = [
# MQTTCosmosDevice(
# query, site, data_source, mqtt_connection, topic_prefix="fdri/cosmos_sites"
# )
# for site in device_ids
# ]

# swarm = Swarm(devices, name="soilmet")
# await swarm.run()
device_objs = [
devices.CR1000XDevice(
site, data_source, mqtt_connection, query=query, sleep_time=5
)
for site in device_ids
]

swarm = Swarm(device_objs, name="soilmet")
await swarm.run()


if __name__ == "__main__":
Expand Down
119 changes: 75 additions & 44 deletions src/iotdevicesimulator/scripts/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

import click
from iotdevicesimulator import queries
from iotdevicesimulator.swarm import CosmosSwarm
from iotdevicesimulator.devices import BaseDevice, CR1000XDevice
from iotdevicesimulator.swarm import Swarm
from iotdevicesimulator.db import Oracle, MockDB
from iotdevicesimulator.messaging.core import MockMessageConnection
from iotdevicesimulator.messaging.aws import IotCoreMQTTConnection
import asyncio
Expand All @@ -12,6 +14,13 @@
TABLES = [table.name for table in queries.CosmosQuery]


@click.command
@click.pass_context
def test(ctx: click.Context):
"""Enables testing of cosmos group arguments."""
print(ctx.obj)


@click.group()
@click.pass_context
@click.option(
Expand All @@ -29,6 +38,9 @@ def main(ctx: click.Context, log_config: Path):
logging.config.fileConfig(fname=log_config)


main.add_command(test)


@main.group()
@click.pass_context
@click.option(
Expand Down Expand Up @@ -66,22 +78,24 @@ def cosmos(ctx: click.Context, site: str, dsn: str, user: str, password: str):
ctx.obj["sites"] = site


@cosmos.command()
def test():
"""Enables testing of cosmos group arguments."""
pass
cosmos.add_command(test)


@cosmos.command()
@click.pass_context
@click.argument("query", type=click.Choice(TABLES))
def list_sites(ctx, query):
@click.option("--max-sites", type=click.IntRange(min=0), default=0)
def list_sites(ctx, query, max_sites):
"""Lists site IDs from the database from table QUERY."""

async def _list_sites(ctx, query):
oracle = await CosmosSwarm._get_oracle(ctx.obj["credentials"])
sites = await CosmosSwarm._get_sites_from_db(
oracle, queries.CosmosSiteQuery[query]
oracle = await Oracle.create(
dsn=ctx.obj["credentials"]["dsn"],
user=ctx.obj["credentials"]["user"],
password=ctx.obj["credentials"]["password"],
)
sites = await oracle.query_site_ids(
queries.CosmosSiteQuery[query], max_sites=max_sites
)
return sites

Expand Down Expand Up @@ -156,17 +170,17 @@ async def _list_sites(ctx, query):
help="Adds a random delay before the first message from each site up to `--sleep-time`.",
)
@click.option(
"--topic-prefix",
"--mqtt-prefix",
type=click.STRING,
default="fdri/cosmos_test",
help="Prefixes the MQTT topic with a string. Can augment the calculated MQTT topic returned by each site.",
)
@click.option(
"--topic-suffix",
"--mqtt-suffix",
type=click.STRING,
help="Suffixes the MQTT topic with a string. Can augment the calculated MQTT topic returned by each site.",
)
@click.option("--dry", is_flag=True, default=False, help="Doesn't send out any data.")
@click.option("--device-type", type=click.Choice(["basic", "cr1000x"]), default="basic")
def mqtt(
ctx,
provider,
Expand All @@ -181,49 +195,66 @@ def mqtt(
max_sites,
swarm_name,
delay_start,
topic_prefix,
topic_suffix,
mqtt_prefix,
mqtt_suffix,
dry,
device_type,
):
"""Sends The cosmos data via MQTT protocol using PROVIDER.
Data is collected from the db using QUERY and sent using CLIENT_ID.
Currently only supports sending through AWS IoT Core."""
query = queries.CosmosQuery[query]

async def _swarm(query, mqtt_connection, credentials, *args, **kwargs):
swarm = await CosmosSwarm.create(
query, mqtt_connection, credentials, *args, **kwargs
async def _mqtt():
oracle = await Oracle.create(
dsn=ctx.obj["credentials"]["dsn"],
user=ctx.obj["credentials"]["user"],
password=ctx.obj["credentials"]["password"],
)

await swarm.run()
data_query = queries.CosmosQuery[query]
site_query = queries.CosmosSiteQuery[query]

if dry == True:
connection = MockMessageConnection()
elif provider == "aws":
connection = IotCoreMQTTConnection(
endpoint=endpoint,
cert_path=cert_path,
key_path=key_path,
ca_cert_path=ca_cert_path,
client_id=client_id,
)
sites = ctx.obj["sites"]
if len(sites) == 0:
sites = await oracle.query_site_ids(site_query, max_sites=max_sites)

asyncio.run(
_swarm(
query,
connection,
ctx.obj["credentials"],
site_ids=ctx.obj["sites"],
sleep_time=sleep_time,
max_cycles=max_cycles,
max_sites=max_sites,
swarm_name=swarm_name,
delay_start=delay_start,
topic_prefix=topic_prefix,
topic_suffix=topic_suffix,
)
)
if dry == True:
connection = MockMessageConnection()
elif provider == "aws":
connection = IotCoreMQTTConnection(
endpoint=endpoint,
cert_path=cert_path,
key_path=key_path,
ca_cert_path=ca_cert_path,
client_id=client_id,
)

if device_type == "basic":
DeviceClass = BaseDevice
elif device_type == "cr1000x":
DeviceClass = CR1000XDevice

site_devices = [
DeviceClass(
site,
oracle,
connection,
sleep_time=sleep_time,
query=data_query,
max_cycles=max_cycles,
delay_start=delay_start,
mqtt_prefix=mqtt_prefix,
mqtt_suffix=mqtt_suffix,
)
for site in sites
]

swarm = Swarm(site_devices, swarm_name)

await swarm.run()

asyncio.run(_mqtt())


if __name__ == "__main__":
Expand Down

0 comments on commit a0e249c

Please sign in to comment.