Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 63 additions & 62 deletions bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,22 @@
ConversationHandler, CallbackContext, ContextTypes)

from logger_config import logger
from constants import TIMEOUT_IN_SEC, STATION_SELECT_ONE_TIME, STATION_SELECT_SUBSCRIBE, ONE_TIME, SUBSCRIBE, UNSUBSCRIBE, VALID_SUMMARY_INTERVALS
from constants import TIMEOUT_IN_SEC, STATION_SELECT_ONE_TIME, STATION_SELECT_SUBSCRIBE, ONE_TIME, SUBSCRIBE, UNSUBSCRIBE, VALID_SUMMARY_INTERVALS, JOBQUEUE_DELAY, DEFAULT_USER_ID


class PlotBot:

def __init__(self, token, station_config, db=None, admin_id=None):
def __init__(self,
token,
station_config,
db=None,
admin_id=None,
ecmwf=None):

self._admin_id = admin_id
self.app = Application.builder().token(token).build()
self._db = db
self._ecmwf = ecmwf
self._station_names = sorted(
[station["name"] for station in station_config])
self._region_of_stations = {
Expand All @@ -24,14 +30,6 @@ def __init__(self, token, station_config, db=None, admin_id=None):
self._station_regions = sorted(
{station["region"]
for station in station_config})
self._subscriptions = {
station: set()
for station in self._station_names
}
self._one_time_forecast_requests = {
station: set()
for station in self._station_names
}
# filter for stations
self._filter_stations = filters.Regex("^(" +
"|".join(self._station_names) +
Expand Down Expand Up @@ -109,17 +107,53 @@ def __init__(self, token, station_config, db=None, admin_id=None):
self.app.add_handler(one_time_forecast_handler)
self.app.add_error_handler(self._error)

async def connect(self):
await self.app.initialize()
await self.app.updater.start_polling(allowed_updates=Update.ALL_TYPES)
await self.app.start()
logger.info('Bot connected')
self.app.job_queue.run_once(
self._override_basetime,
when=0,
name='Override basetime',
)
self.app.job_queue.run_repeating(
self._update_basetime,
interval=60,
first=60,
name='Update basetime',
)
self.app.job_queue.run_repeating(
self._cache_plots,
interval=30,
first=30,
name='Cache plots',
)
self.app.job_queue.run_repeating(
self._broadcast_from_queue,
interval=90,
first=60,
name='Broadcast',
)

while True:
await asyncio.sleep(1)
async def _override_basetime(self, context: CallbackContext):
self._ecmwf.override_base_time_from_init()

async def _update_basetime(self, context: CallbackContext):
self._ecmwf.upgrade_basetime_global()
self._ecmwf.upgrade_basetime_stations()

async def _send_plot_from_queue(self, context: CallbackContext):
job = context.job
user_id, station_name = job.data
plots = self._ecmwf.download_plots([station_name])
await self._send_plot_to_user(plots, station_name, user_id)

def start(self):
logger.info('Starting bot')
self.app.run_polling(allowed_updates=Update.ALL_TYPES)

async def _error(self, update: Update, context: CallbackContext):
user_id = update.message.chat_id

if update:
user_id = update.message.chat_id
else:
user_id = DEFAULT_USER_ID
logger.error(f"Exception while handling an update: {context.error}")
self._db.log_activity(
activity_type="bot-error",
Expand Down Expand Up @@ -303,9 +337,11 @@ async def _subscribe_for_station(self, update: Update,
reply_markup=ReplyKeyboardRemove(),
)
self._db.add_subscription(msg_text, user.id)
self._subscriptions[msg_text].add(user.id)

logger.info(f' {user.first_name} subscribed for Station {msg_text}')
context.job_queue.run_once(self._send_plot_from_queue,
JOBQUEUE_DELAY,
data=(user.id, msg_text))

self._db.log_activity(
activity_type="subscription",
Expand All @@ -324,8 +360,10 @@ async def _request_one_time_forecast_for_station(
reply_text,
reply_markup=ReplyKeyboardRemove(),
)
self._one_time_forecast_requests[msg_text].add(user.id)

context.job_queue.run_once(self._send_plot_from_queue,
JOBQUEUE_DELAY,
data=(user.id, msg_text))
logger.info(
f' {user.first_name} requested forecast for Station {msg_text}')

Expand All @@ -346,24 +384,8 @@ async def _cancel(self, update: Update, context: CallbackContext) -> int:

return ConversationHandler.END

def has_new_subscribers_waiting(self):
return any(users for users in self._subscriptions.values())

def has_one_time_forecast_waiting(self):
return any(users
for users in self._one_time_forecast_requests.values())

def stations_of_one_time_request(self):
return [
station
for station, users in self._one_time_forecast_requests.items()
if users
]

def stations_of_new_subscribers(self):
return [
station for station, users in self._subscriptions.items() if users
]
async def _cache_plots(self, context: CallbackContext):
self._ecmwf.cache_plots()

async def _send_plot_to_user(self, plots, station_name, user_id):
logger.debug(f'Send plot to user: {user_id}')
Expand All @@ -375,30 +397,9 @@ async def _send_plot_to_user(self, plots, station_name, user_id):
except Exception as e:
logger.error(f'Error sending plot to user {user_id}: {e}')

async def _send_plots(self, plots, requests):
for station_name, users in requests.items():
for user_id in users:
await self._send_plot_to_user(plots, station_name, user_id)

async def send_plots_to_new_subscribers(self, plots):
await self._send_plots(plots, self._subscriptions)
logger.info('plots sent to new subscribers')

self._subscriptions = {
station: set()
for station in self._station_names
}

async def send_one_time_forecast(self, plots):
await self._send_plots(plots, self._one_time_forecast_requests)
logger.info('plots sent to one time forecast requests')

self._one_time_forecast_requests = {
station: set()
for station in self._station_names
}

async def broadcast(self, plots):
async def _broadcast_from_queue(self, context: CallbackContext):
plots = self._ecmwf.download_latest_plots(
self._db.stations_with_subscribers())
if plots:
for station_name in plots:
for user_id in self._db.get_subscriptions_by_station(
Expand Down
4 changes: 4 additions & 0 deletions constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,7 @@
STATION_SELECT_ONE_TIME, STATION_SELECT_SUBSCRIBE, ONE_TIME, SUBSCRIBE, UNSUBSCRIBE = range(
5)
VALID_SUMMARY_INTERVALS = ['24 HOURS', '7 DAYS', '30 DAYS', '1 YEAR']

JOBQUEUE_DELAY = 10

DEFAULT_USER_ID = 999
67 changes: 10 additions & 57 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,14 @@
import logging
import argparse
import time
import yaml
import sys
import threading
import asyncio

from ecmwf import EcmwfApi
from bot import PlotBot
from logger_config import logger
from db import Database


async def await_func(func, *args):
async_func = asyncio.create_task(func(*args))
await async_func


def run_asyncio(func, *args):
asyncio.run(await_func(func, *args))


def run_asyncio_in_thread(func, name, *args):
thread = threading.Thread(target=run_asyncio,
name=name,
daemon=True,
args=[func, *args])
thread.start()
logging.debug(f'Started thread: {name}')


def start_bot(token, station_config, admin_id, db):
bot = PlotBot(token, station_config, admin_id=admin_id, db=db)
run_asyncio_in_thread(bot.connect, 'bot-connect')
return bot


def main():

parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -66,39 +39,19 @@ def main():
with open('stations.yaml', 'r') as file:
station_config = yaml.safe_load(file)

ecmwf = EcmwfApi(station_config)

db = Database('config.yml')

bot = start_bot(args.bot_token, station_config, args.admin_id, db)
bot = PlotBot(args.bot_token,
station_config,
admin_id=args.admin_id,
db=db,
ecmwf=ecmwf)
bot.start()

ecmwf = EcmwfApi(station_config)
ecmwf.override_base_time_from_init()

logger.info('Enter infinite loop')

while True:

try:
ecmwf.upgrade_basetime_global()
ecmwf.upgrade_basetime_stations()
if bot.has_new_subscribers_waiting():
run_asyncio_in_thread(
bot.send_plots_to_new_subscribers, 'new-subscribers',
ecmwf.download_plots(bot.stations_of_new_subscribers()))
if bot.has_one_time_forecast_waiting():
run_asyncio_in_thread(
bot.send_one_time_forecast, 'one-time-forecast',
ecmwf.download_plots(bot.stations_of_one_time_request()))
run_asyncio_in_thread(
bot.broadcast, 'broadcast',
ecmwf.download_latest_plots(db.stations_with_subscribers()))
ecmwf.cache_plots()
except Exception as e:
logger.error(f'An error occured: {e}')
sys.exit(1)

snooze = 5
logger.debug(f'snooze {snooze}s ...')
time.sleep(snooze)
# we should not be here
sys.exit(1)


Comment on lines +53 to 56
Copy link

Copilot AI Apr 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment '# we should not be here' together with sys.exit(1) could lead to confusion if bot.start() blocks indefinitely. Consider clarifying the intended control flow or removing unreachable code.

Copilot uses AI. Check for mistakes.
if __name__ == '__main__':
Expand Down
Loading