Skip to content

Commit

Permalink
Updated to allow CLI to call looped database
Browse files Browse the repository at this point in the history
  • Loading branch information
lewis-chambers committed Jun 17, 2024
1 parent e0983da commit b234463
Show file tree
Hide file tree
Showing 4 changed files with 257 additions and 81 deletions.
26 changes: 25 additions & 1 deletion src/iotswarm/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from iotswarm.queries import CosmosQuery, CosmosSiteQuery
import pandas as pd
from pathlib import Path
from math import nan

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -193,11 +194,34 @@ def query_latest_from_site(self, site_id: str) -> dict:
A dict of the data row.
"""

data = self.connection.query("SITE_ID == @site_id")
data = self.connection.query("SITE_ID == @site_id").replace({nan: None})

if site_id not in self.cache or self.cache[site_id] >= len(data):
self.cache[site_id] = 1
else:
self.cache[site_id] += 1

return data.iloc[self.cache[site_id] - 1].to_dict()

def query_site_ids(self, max_sites: int | None = None) -> list:
"""query_site_ids returns a list of site IDs from the database
Args:
max_sites: Maximum number of sites to retreive
Returns:
List[str]: A list of site ID strings.
"""
if max_sites is not None:
max_sites = int(max_sites)
if max_sites < 0:
raise ValueError(
f"`max_sites` must be 1 or more, or 0 for no maximum. Received: {max_sites}"
)

sites = self.connection["SITE_ID"].drop_duplicates().to_list()

if max_sites is not None and max_sites > 0:
sites = sites[:max_sites]

return sites
6 changes: 2 additions & 4 deletions src/iotswarm/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,12 +275,10 @@ async def _get_payload(self):
return await self.data_source.query_latest_from_site(
self.device_id, self.query
)

elif isinstance(self.data_source, BaseDatabase):
return self.data_source.query_latest_from_site()

elif isinstance(self.data_source, LoopingCsvDB):
return self.data_source.query_latest_from_site(self.device_id)
elif isinstance(self.data_source, BaseDatabase):
return self.data_source.query_latest_from_site()

def _format_payload(self, payload):
"""Oranises payload into correct structure."""
Expand Down
282 changes: 206 additions & 76 deletions src/iotswarm/scripts/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from iotswarm import queries
from iotswarm.devices import BaseDevice, CR1000XDevice
from iotswarm.swarm import Swarm
from iotswarm.db import Oracle
from iotswarm.db import Oracle, LoopingCsvDB
from iotswarm.messaging.core import MockMessageConnection
from iotswarm.messaging.aws import IotCoreMQTTConnection
import asyncio
Expand Down Expand Up @@ -161,88 +161,110 @@ async def _list_sites(ctx, query):
click.echo(asyncio.run(_list_sites(ctx, query)))


def common_device_options(function):
click.option(
"--sleep-time",
type=click.INT,
help="The number of seconds each site goes idle after sending a message.",
)(function)

click.option(
"--max-cycles",
type=click.IntRange(0),
help="Maximum number message sending cycles. Runs forever if set to 0.",
)(function)

click.option(
"--max-sites",
type=click.IntRange(0),
help="Maximum number of sites allowed to initialize. No limit if set to 0.",
)(function)

click.option(
"--swarm-name",
type=click.STRING,
help="Name given to swarm. Appears in the logs.",
)(function)

click.option(
"--delay-start",
is_flag=True,
default=False,
help="Adds a random delay before the first message from each site up to `--sleep-time`.",
)(function)

click.option(
"--device-type", type=click.Choice(["basic", "cr1000x"]), default="basic"
)(function)

return function


def common_iotcore_options(function):
click.argument(
"client-id",
type=click.STRING,
required=True,
)(function)

click.option(
"--endpoint",
type=click.STRING,
required=True,
envvar="IOT_SWARM_MQTT_ENDPOINT",
help="Endpoint of the MQTT receiving host.",
)(function)

click.option(
"--cert-path",
type=click.Path(exists=True),
required=True,
envvar="IOT_SWARM_MQTT_CERT_PATH",
help="Path to public key certificate for the device. Must match key assigned to the `--client-id` in the cloud provider.",
)(function)

click.option(
"--key-path",
type=click.Path(exists=True),
required=True,
envvar="IOT_SWARM_MQTT_KEY_PATH",
help="Path to the private key that pairs with the `--cert-path`.",
)(function)

click.option(
"--ca-cert-path",
type=click.Path(exists=True),
required=True,
envvar="IOT_SWARM_MQTT_CA_CERT_PATH",
help="Path to the root Certificate Authority (CA) for the MQTT host.",
)(function)

click.option(
"--mqtt-prefix",
type=click.STRING,
help="Prefixes the MQTT topic with a string. Can augment the calculated MQTT topic returned by each site.",
)(function)

click.option(
"--mqtt-suffix",
type=click.STRING,
help="Suffixes the MQTT topic with a string. Can augment the calculated MQTT topic returned by each site.",
)(function)

return function


@cosmos.command()
@click.pass_context
@click.argument(
"provider",
type=click.Choice(["aws"]),
)
@click.argument(
"query",
type=click.Choice(TABLES),
)
@click.argument(
"client-id",
type=click.STRING,
required=True,
)
@click.option(
"--endpoint",
type=click.STRING,
required=True,
envvar="IOT_SWARM_MQTT_ENDPOINT",
help="Endpoint of the MQTT receiving host.",
)
@click.option(
"--cert-path",
type=click.Path(exists=True),
required=True,
envvar="IOT_SWARM_MQTT_CERT_PATH",
help="Path to public key certificate for the device. Must match key assigned to the `--client-id` in the cloud provider.",
)
@click.option(
"--key-path",
type=click.Path(exists=True),
required=True,
envvar="IOT_SWARM_MQTT_KEY_PATH",
help="Path to the private key that pairs with the `--cert-path`.",
)
@click.option(
"--ca-cert-path",
type=click.Path(exists=True),
required=True,
envvar="IOT_SWARM_MQTT_CA_CERT_PATH",
help="Path to the root Certificate Authority (CA) for the MQTT host.",
)
@click.option(
"--sleep-time",
type=click.INT,
help="The number of seconds each site goes idle after sending a message.",
)
@click.option(
"--max-cycles",
type=click.IntRange(0),
help="Maximum number message sending cycles. Runs forever if set to 0.",
)
@click.option(
"--max-sites",
type=click.IntRange(0),
help="Maximum number of sites allowed to initialize. No limit if set to 0.",
)
@click.option(
"--swarm-name", type=click.STRING, help="Name given to swarm. Appears in the logs."
)
@click.option(
"--delay-start",
is_flag=True,
default=False,
help="Adds a random delay before the first message from each site up to `--sleep-time`.",
)
@click.option(
"--mqtt-prefix",
type=click.STRING,
help="Prefixes the MQTT topic with a string. Can augment the calculated MQTT topic returned by each site.",
)
@click.option(
"--mqtt-suffix",
type=click.STRING,
help="Suffixes the MQTT topic with a string. Can augment the calculated MQTT topic returned by each site.",
)
@common_device_options
@common_iotcore_options
@click.option("--dry", is_flag=True, default=False, help="Doesn't send out any data.")
@click.option("--device-type", type=click.Choice(["basic", "cr1000x"]), default="basic")
def mqtt(
ctx,
provider,
query,
endpoint,
cert_path,
Expand All @@ -259,7 +281,7 @@ def mqtt(
dry,
device_type,
):
"""Sends The cosmos data via MQTT protocol using PROVIDER.
"""Sends The cosmos data via MQTT protocol using IoT Core.
Data is collected from the db using QUERY and sent using CLIENT_ID.
Currently only supports sending through AWS IoT Core."""
Expand All @@ -280,7 +302,7 @@ async def _mqtt():

if dry == True:
connection = MockMessageConnection()
elif provider == "aws":
else:
connection = IotCoreMQTTConnection(
endpoint=endpoint,
cert_path=cert_path,
Expand Down Expand Up @@ -316,5 +338,113 @@ async def _mqtt():
asyncio.run(_mqtt())


@main.group()
@click.pass_context
@click.option(
"--site",
type=click.STRING,
multiple=True,
help="Adds a site to be initialized. Can be invoked multiple times for other sites."
" Grabs all sites from database query if none provided",
)
@click.option(
"--file",
type=click.Path(exists=True),
required=True,
envvar="IOT_SWARM_CSV_DB",
help="*.csv file used to instantiate a pandas database.",
)
def looping_csv(ctx, site, file):
"""Instantiates a pandas dataframe from a csv file which is used as the database.
Responsibility falls on the user to ensure the correct file is selected."""

ctx.obj["db"] = LoopingCsvDB(file)
ctx.obj["sites"] = site


looping_csv.add_command(test)


@looping_csv.command
@click.pass_context
@click.option("--max-sites", type=click.IntRange(min=0), default=0)
def list_sites(ctx, max_sites):
"""Prints the sites present in database."""

sites = ctx.obj["db"].query_site_ids(max_sites=max_sites)
click.echo(sites)


@looping_csv.command()
@click.pass_context
@common_device_options
@common_iotcore_options
@click.option("--dry", is_flag=True, default=False, help="Doesn't send out any data.")
def mqtt(
ctx,
endpoint,
cert_path,
key_path,
ca_cert_path,
client_id,
sleep_time,
max_cycles,
max_sites,
swarm_name,
delay_start,
mqtt_prefix,
mqtt_suffix,
dry,
device_type,
):
"""Sends The cosmos data via MQTT protocol using IoT Core.
Data is collected from the db using QUERY and sent using CLIENT_ID.
Currently only supports sending through AWS IoT Core."""

async def _mqtt():

sites = ctx.obj["sites"]
db = ctx.obj["db"]
if len(sites) == 0:
sites = db.query_site_ids(max_sites=max_sites)

if dry == True:
connection = MockMessageConnection()
else:
connection = IotCoreMQTTConnection(
endpoint=endpoint,
cert_path=cert_path,
key_path=key_path,
ca_cert_path=ca_cert_path,
client_id=client_id,
)

if device_type == "basic":
DeviceClass = BaseDevice
elif device_type == "cr1000x":
DeviceClass = CR1000XDevice

site_devices = [
DeviceClass(
site,
db,
connection,
sleep_time=sleep_time,
max_cycles=max_cycles,
delay_start=delay_start,
mqtt_prefix=mqtt_prefix,
mqtt_suffix=mqtt_suffix,
)
for site in sites
]

swarm = Swarm(site_devices, swarm_name)

await swarm.run()

asyncio.run(_mqtt())


if __name__ == "__main__":
main(auto_envvar_prefix="IOT_SWARM", obj={})
Loading

0 comments on commit b234463

Please sign in to comment.