From 7a6ed85607477609ee3e5207c2fc87edd36e00b6 Mon Sep 17 00:00:00 2001 From: Jonas Jucker Date: Fri, 18 Apr 2025 15:30:39 +0200 Subject: [PATCH 1/3] use job queue --- bot.py | 119 ++++++++++++++++++++++----------------------------- constants.py | 2 + main.py | 62 +++------------------------ 3 files changed, 57 insertions(+), 126 deletions(-) diff --git a/bot.py b/bot.py index 40211ed..464a4ea 100644 --- a/bot.py +++ b/bot.py @@ -5,16 +5,17 @@ ConversationHandler, CallbackContext, ContextTypes) from logger_config import logger -from constants import TIMEOUT_IN_SEC, STATION_SELECT_ONE_TIME, STATION_SELECT_SUBSCRIBE, ONE_TIME, SUBSCRIBE, UNSUBSCRIBE, VALID_SUMMARY_INTERVALS +from constants import TIMEOUT_IN_SEC, STATION_SELECT_ONE_TIME, STATION_SELECT_SUBSCRIBE, ONE_TIME, SUBSCRIBE, UNSUBSCRIBE, VALID_SUMMARY_INTERVALS, JOBQUEUE_DELAY class PlotBot: - def __init__(self, token, station_config, db=None, admin_id=None): + def __init__(self, token, station_config, db=None, admin_id=None, ecmwf=None): self._admin_id = admin_id self.app = Application.builder().token(token).build() self._db = db + self._ecmwf = ecmwf self._station_names = sorted( [station["name"] for station in station_config]) self._region_of_stations = { @@ -24,14 +25,6 @@ def __init__(self, token, station_config, db=None, admin_id=None): self._station_regions = sorted( {station["region"] for station in station_config}) - self._subscriptions = { - station: set() - for station in self._station_names - } - self._one_time_forecast_requests = { - station: set() - for station in self._station_names - } # filter for stations self._filter_stations = filters.Regex("^(" + "|".join(self._station_names) + @@ -109,23 +102,49 @@ def __init__(self, token, station_config, db=None, admin_id=None): self.app.add_handler(one_time_forecast_handler) self.app.add_error_handler(self._error) - async def connect(self): - await self.app.initialize() - await self.app.updater.start_polling(allowed_updates=Update.ALL_TYPES) - await self.app.start() - logger.info('Bot connected') - - while True: - await asyncio.sleep(1) + self.app.job_queue.run_once(self._override_basetime, + when=0, + name='Override basetime',) + self.app.job_queue.run_repeating(self._update_basetime, + interval=60, + first=60, + name='Update basetime',) + self.app.job_queue.run_repeating(self._cache_plots, + interval=30, + first=30, + name='Cache plots',) + self.app.job_queue.run_repeating(self._broadcast_from_queue, + interval=90, + first=60, + name='Broadcast',) + + async def _override_basetime(self, context: CallbackContext): + logger.info('Overriding basetime') + self._ecmwf.override_base_time_from_init() + + async def _update_basetime(self, context: CallbackContext): + logger.info('Updating basetime') + self._ecmwf.upgrade_basetime_global() + self._ecmwf.upgrade_basetime_stations() + + async def _send_plot_from_queue(self, context: CallbackContext): + job = context.job + user_id, station_name = job.data + plots = self._ecmwf.download_plots([station_name]) + await self._send_plot_to_user(plots, station_name, user_id) + + def start(self): + logger.info('Starting bot') + self.app.run_polling(allowed_updates=Update.ALL_TYPES) async def _error(self, update: Update, context: CallbackContext): - user_id = update.message.chat_id + #user_id = update.message.chat_id logger.error(f"Exception while handling an update: {context.error}") - self._db.log_activity( - activity_type="bot-error", - user_id=user_id, - station="unknown", - ) + #self._db.log_activity( + # activity_type="bot-error", + # user_id=user_id, + # station="unknown", + #) async def _stats(self, update: Update, context: CallbackContext): user_id = update.message.chat_id @@ -303,9 +322,9 @@ async def _subscribe_for_station(self, update: Update, reply_markup=ReplyKeyboardRemove(), ) self._db.add_subscription(msg_text, user.id) - self._subscriptions[msg_text].add(user.id) logger.info(f' {user.first_name} subscribed for Station {msg_text}') + context.job_queue.run_once(self._send_plot_from_queue, JOBQUEUE_DELAY, data=(user.id, msg_text)) self._db.log_activity( activity_type="subscription", @@ -324,8 +343,8 @@ async def _request_one_time_forecast_for_station( reply_text, reply_markup=ReplyKeyboardRemove(), ) - self._one_time_forecast_requests[msg_text].add(user.id) + context.job_queue.run_once(self._send_plot_from_queue, JOBQUEUE_DELAY, data=(user.id, msg_text)) logger.info( f' {user.first_name} requested forecast for Station {msg_text}') @@ -346,24 +365,8 @@ async def _cancel(self, update: Update, context: CallbackContext) -> int: return ConversationHandler.END - def has_new_subscribers_waiting(self): - return any(users for users in self._subscriptions.values()) - - def has_one_time_forecast_waiting(self): - return any(users - for users in self._one_time_forecast_requests.values()) - - def stations_of_one_time_request(self): - return [ - station - for station, users in self._one_time_forecast_requests.items() - if users - ] - - def stations_of_new_subscribers(self): - return [ - station for station, users in self._subscriptions.items() if users - ] + async def _cache_plots(self, context: CallbackContext): + self._ecmwf.cache_plots() async def _send_plot_to_user(self, plots, station_name, user_id): logger.debug(f'Send plot to user: {user_id}') @@ -375,33 +378,11 @@ async def _send_plot_to_user(self, plots, station_name, user_id): except Exception as e: logger.error(f'Error sending plot to user {user_id}: {e}') - async def _send_plots(self, plots, requests): - for station_name, users in requests.items(): - for user_id in users: - await self._send_plot_to_user(plots, station_name, user_id) - - async def send_plots_to_new_subscribers(self, plots): - await self._send_plots(plots, self._subscriptions) - logger.info('plots sent to new subscribers') - - self._subscriptions = { - station: set() - for station in self._station_names - } - - async def send_one_time_forecast(self, plots): - await self._send_plots(plots, self._one_time_forecast_requests) - logger.info('plots sent to one time forecast requests') - - self._one_time_forecast_requests = { - station: set() - for station in self._station_names - } - - async def broadcast(self, plots): + async def _broadcast_from_queue(self, context: CallbackContext): + plots = self._ecmwf.download_latest_plots(self._db.stations_with_subscribers()) if plots: for station_name in plots: for user_id in self._db.get_subscriptions_by_station( station_name): await self._send_plot_to_user(plots, station_name, user_id) - logger.info('plots sent to all users') + logger.info('plots sent to all users') \ No newline at end of file diff --git a/constants.py b/constants.py index 3988c42..ebc0ad8 100644 --- a/constants.py +++ b/constants.py @@ -7,3 +7,5 @@ STATION_SELECT_ONE_TIME, STATION_SELECT_SUBSCRIBE, ONE_TIME, SUBSCRIBE, UNSUBSCRIBE = range( 5) VALID_SUMMARY_INTERVALS = ['24 HOURS', '7 DAYS', '30 DAYS', '1 YEAR'] + +JOBQUEUE_DELAY = 10 diff --git a/main.py b/main.py index 476412c..ce60dc4 100644 --- a/main.py +++ b/main.py @@ -1,41 +1,13 @@ import logging import argparse -import time import yaml import sys -import threading -import asyncio from ecmwf import EcmwfApi from bot import PlotBot from logger_config import logger from db import Database - -async def await_func(func, *args): - async_func = asyncio.create_task(func(*args)) - await async_func - - -def run_asyncio(func, *args): - asyncio.run(await_func(func, *args)) - - -def run_asyncio_in_thread(func, name, *args): - thread = threading.Thread(target=run_asyncio, - name=name, - daemon=True, - args=[func, *args]) - thread.start() - logging.debug(f'Started thread: {name}') - - -def start_bot(token, station_config, admin_id, db): - bot = PlotBot(token, station_config, admin_id=admin_id, db=db) - run_asyncio_in_thread(bot.connect, 'bot-connect') - return bot - - def main(): parser = argparse.ArgumentParser() @@ -66,39 +38,15 @@ def main(): with open('stations.yaml', 'r') as file: station_config = yaml.safe_load(file) - db = Database('config.yml') - - bot = start_bot(args.bot_token, station_config, args.admin_id, db) - ecmwf = EcmwfApi(station_config) - ecmwf.override_base_time_from_init() - - logger.info('Enter infinite loop') - while True: + db = Database('config.yml') - try: - ecmwf.upgrade_basetime_global() - ecmwf.upgrade_basetime_stations() - if bot.has_new_subscribers_waiting(): - run_asyncio_in_thread( - bot.send_plots_to_new_subscribers, 'new-subscribers', - ecmwf.download_plots(bot.stations_of_new_subscribers())) - if bot.has_one_time_forecast_waiting(): - run_asyncio_in_thread( - bot.send_one_time_forecast, 'one-time-forecast', - ecmwf.download_plots(bot.stations_of_one_time_request())) - run_asyncio_in_thread( - bot.broadcast, 'broadcast', - ecmwf.download_latest_plots(db.stations_with_subscribers())) - ecmwf.cache_plots() - except Exception as e: - logger.error(f'An error occured: {e}') - sys.exit(1) + bot = PlotBot(args.bot_token, station_config, admin_id=args.admin_id, db=db, ecmwf=ecmwf) + bot.start() - snooze = 5 - logger.debug(f'snooze {snooze}s ...') - time.sleep(snooze) + # we should not be here + sys.exit(1) if __name__ == '__main__': From 745200d3536be203985bfaf1ae37209e419d7fa9 Mon Sep 17 00:00:00 2001 From: github-actions Date: Fri, 18 Apr 2025 13:30:57 +0000 Subject: [PATCH 2/3] GitHub Action: Apply Pep8-formatting --- bot.py | 58 +++++++++++++++++++++++++++++++++++++-------------------- main.py | 7 ++++++- 2 files changed, 44 insertions(+), 21 deletions(-) diff --git a/bot.py b/bot.py index 464a4ea..dcfdda8 100644 --- a/bot.py +++ b/bot.py @@ -10,7 +10,12 @@ class PlotBot: - def __init__(self, token, station_config, db=None, admin_id=None, ecmwf=None): + def __init__(self, + token, + station_config, + db=None, + admin_id=None, + ecmwf=None): self._admin_id = admin_id self.app = Application.builder().token(token).build() @@ -102,21 +107,29 @@ def __init__(self, token, station_config, db=None, admin_id=None, ecmwf=None): self.app.add_handler(one_time_forecast_handler) self.app.add_error_handler(self._error) - self.app.job_queue.run_once(self._override_basetime, - when=0, - name='Override basetime',) - self.app.job_queue.run_repeating(self._update_basetime, - interval=60, - first=60, - name='Update basetime',) - self.app.job_queue.run_repeating(self._cache_plots, - interval=30, - first=30, - name='Cache plots',) - self.app.job_queue.run_repeating(self._broadcast_from_queue, - interval=90, - first=60, - name='Broadcast',) + self.app.job_queue.run_once( + self._override_basetime, + when=0, + name='Override basetime', + ) + self.app.job_queue.run_repeating( + self._update_basetime, + interval=60, + first=60, + name='Update basetime', + ) + self.app.job_queue.run_repeating( + self._cache_plots, + interval=30, + first=30, + name='Cache plots', + ) + self.app.job_queue.run_repeating( + self._broadcast_from_queue, + interval=90, + first=60, + name='Broadcast', + ) async def _override_basetime(self, context: CallbackContext): logger.info('Overriding basetime') @@ -324,7 +337,9 @@ async def _subscribe_for_station(self, update: Update, self._db.add_subscription(msg_text, user.id) logger.info(f' {user.first_name} subscribed for Station {msg_text}') - context.job_queue.run_once(self._send_plot_from_queue, JOBQUEUE_DELAY, data=(user.id, msg_text)) + context.job_queue.run_once(self._send_plot_from_queue, + JOBQUEUE_DELAY, + data=(user.id, msg_text)) self._db.log_activity( activity_type="subscription", @@ -344,7 +359,9 @@ async def _request_one_time_forecast_for_station( reply_markup=ReplyKeyboardRemove(), ) - context.job_queue.run_once(self._send_plot_from_queue, JOBQUEUE_DELAY, data=(user.id, msg_text)) + context.job_queue.run_once(self._send_plot_from_queue, + JOBQUEUE_DELAY, + data=(user.id, msg_text)) logger.info( f' {user.first_name} requested forecast for Station {msg_text}') @@ -379,10 +396,11 @@ async def _send_plot_to_user(self, plots, station_name, user_id): logger.error(f'Error sending plot to user {user_id}: {e}') async def _broadcast_from_queue(self, context: CallbackContext): - plots = self._ecmwf.download_latest_plots(self._db.stations_with_subscribers()) + plots = self._ecmwf.download_latest_plots( + self._db.stations_with_subscribers()) if plots: for station_name in plots: for user_id in self._db.get_subscriptions_by_station( station_name): await self._send_plot_to_user(plots, station_name, user_id) - logger.info('plots sent to all users') \ No newline at end of file + logger.info('plots sent to all users') diff --git a/main.py b/main.py index ce60dc4..06e0fa0 100644 --- a/main.py +++ b/main.py @@ -8,6 +8,7 @@ from logger_config import logger from db import Database + def main(): parser = argparse.ArgumentParser() @@ -42,7 +43,11 @@ def main(): db = Database('config.yml') - bot = PlotBot(args.bot_token, station_config, admin_id=args.admin_id, db=db, ecmwf=ecmwf) + bot = PlotBot(args.bot_token, + station_config, + admin_id=args.admin_id, + db=db, + ecmwf=ecmwf) bot.start() # we should not be here From c41348bf36cfa99f27fe78f6810373ac18a63bb9 Mon Sep 17 00:00:00 2001 From: Jonas Jucker Date: Fri, 18 Apr 2025 17:30:44 +0200 Subject: [PATCH 3/3] fix error --- bot.py | 20 +++++++++++--------- constants.py | 2 ++ 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/bot.py b/bot.py index dcfdda8..50390e6 100644 --- a/bot.py +++ b/bot.py @@ -5,7 +5,7 @@ ConversationHandler, CallbackContext, ContextTypes) from logger_config import logger -from constants import TIMEOUT_IN_SEC, STATION_SELECT_ONE_TIME, STATION_SELECT_SUBSCRIBE, ONE_TIME, SUBSCRIBE, UNSUBSCRIBE, VALID_SUMMARY_INTERVALS, JOBQUEUE_DELAY +from constants import TIMEOUT_IN_SEC, STATION_SELECT_ONE_TIME, STATION_SELECT_SUBSCRIBE, ONE_TIME, SUBSCRIBE, UNSUBSCRIBE, VALID_SUMMARY_INTERVALS, JOBQUEUE_DELAY, DEFAULT_USER_ID class PlotBot: @@ -132,11 +132,9 @@ def __init__(self, ) async def _override_basetime(self, context: CallbackContext): - logger.info('Overriding basetime') self._ecmwf.override_base_time_from_init() async def _update_basetime(self, context: CallbackContext): - logger.info('Updating basetime') self._ecmwf.upgrade_basetime_global() self._ecmwf.upgrade_basetime_stations() @@ -151,13 +149,17 @@ def start(self): self.app.run_polling(allowed_updates=Update.ALL_TYPES) async def _error(self, update: Update, context: CallbackContext): - #user_id = update.message.chat_id + + if update: + user_id = update.message.chat_id + else: + user_id = DEFAULT_USER_ID logger.error(f"Exception while handling an update: {context.error}") - #self._db.log_activity( - # activity_type="bot-error", - # user_id=user_id, - # station="unknown", - #) + self._db.log_activity( + activity_type="bot-error", + user_id=user_id, + station="unknown", + ) async def _stats(self, update: Update, context: CallbackContext): user_id = update.message.chat_id diff --git a/constants.py b/constants.py index ebc0ad8..af9b9c0 100644 --- a/constants.py +++ b/constants.py @@ -9,3 +9,5 @@ VALID_SUMMARY_INTERVALS = ['24 HOURS', '7 DAYS', '30 DAYS', '1 YEAR'] JOBQUEUE_DELAY = 10 + +DEFAULT_USER_ID = 999