From 6a28055ea2ad9a7748899c4c692a045887aa0a59 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9s=20Us=C3=A1n?= <5434104+andruten@users.noreply.github.com> Date: Fri, 31 Jan 2025 23:58:12 +0100 Subject: [PATCH] feat: Remove Sockets and clean some code (#85) * feat: Remove Sockets and clean some code * feat: reorganized command handlers and filters * chore: refactor * flake8 * typo --------- Co-authored-by: andruten --- README.md | 8 +-- backends.py | 60 ------------------- backends/__init__.py | 5 ++ backends/base_backend.py | 16 +++++ backends/request_backend.py | 33 ++++++++++ commands/add_service.py | 19 ++---- commands/check_all_services.py | 12 +++- command_handlers.py => commands/handlers.py | 18 +++--- commands/list_services.py | 2 +- commands/remove_service.py | 2 +- filters/__init__.py | 0 .../allowed_chats.py | 2 +- main.py | 8 ++- models/__init__.py | 6 ++ models/service.py | 48 +++++++++++++++ persistence/__init__.py | 7 +++ persistence/base_persistence.py | 30 ++++++++++ .../local_json_repository.py | 31 +--------- repositories/__init__.py | 5 ++ .../service_repository.py | 60 ++----------------- tests/test_backends.py | 36 +---------- tests/test_command_handlers.py | 60 ++++++++----------- tests/test_services.py | 25 ++++---- 23 files changed, 231 insertions(+), 262 deletions(-) delete mode 100644 backends.py create mode 100644 backends/__init__.py create mode 100644 backends/base_backend.py create mode 100644 backends/request_backend.py rename command_handlers.py => commands/handlers.py (85%) create mode 100644 filters/__init__.py rename filter_allowed_chats.py => filters/allowed_chats.py (92%) create mode 100644 models/__init__.py create mode 100644 models/service.py create mode 100644 persistence/__init__.py create mode 100644 persistence/base_persistence.py rename persistence.py => persistence/local_json_repository.py (76%) create mode 100644 repositories/__init__.py rename models.py => repositories/service_repository.py (55%) diff --git a/README.md b/README.md index 9112784..949d266 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # Health check bot Welcome to Health check bot 👋!! I'm a python bot for telegram which intends to implement a (very basic) healthcheck -system. I can make HTTP requests or open Sockets. +system. I perform HTTP requests and I'll let you know if your service is healthy :). ## Available commands @@ -14,12 +14,10 @@ List all polling services ### add Add new service to the polling list ``` -/add +/add ``` -- `service_type` must be "socket" or "request" - `name` is the name that will be stored -- `domain` is the domain that will be reached by the service -- `port` is the port number +- `url` is the url that will be reached by the service ### remove Unsubscribe a service from the polling list by name diff --git a/backends.py b/backends.py deleted file mode 100644 index 0d08baf..0000000 --- a/backends.py +++ /dev/null @@ -1,60 +0,0 @@ -from abc import ABC, abstractmethod -import logging -from datetime import datetime -from socket import AF_INET, SOCK_STREAM, error, socket, timeout -from typing import Tuple, Optional - -import httpx -import ssl - -logger = logging.getLogger(__name__) - - -class BaseBackend(ABC): - def __init__(self, service) -> None: - self.service = service - - @abstractmethod - async def check(self, *args, **kwargs) -> Tuple[bool, Optional[float], Optional[datetime]]: # pragma: no cover - pass - - -class SocketBackend(BaseBackend): - def check(self) -> Tuple[bool, Optional[float], Optional[datetime]]: - a_socket = socket(AF_INET, SOCK_STREAM) - location = (self.service.domain, self.service.port) - try: - result_of_check = a_socket.connect_ex(location) - except (error, timeout): - return False, None, None - else: - a_socket.close() - return result_of_check == 0, None, None - - -class RequestBackend(BaseBackend): - def _get_url(self) -> str: - protocol = 'https' if self.service.port == 443 else 'http' - return f'{protocol}://{self.service.domain}' - - async def check(self, session) -> Tuple[bool, Optional[float], Optional[datetime]]: - url = self._get_url() - headers = { - 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 ' - '(KHTML, like Gecko) Chrome/91.0.4472.114 Safari/537.36' - } - try: - logger.debug(f"Fetching {url}") - response = await session.request(method='GET', url=url, headers=headers) - except (httpx.HTTPError, ssl.SSLCertVerificationError, ) as exc: - logger.warning(f'"{url}" request failed {exc}') - return False, None, None - else: - raw_stream = response.extensions['network_stream'] - ssl_object = raw_stream.get_extra_info('ssl_object') - cert = ssl_object.getpeercert() - expire_date = datetime.strptime(cert['notAfter'], '%b %d %H:%M:%S %Y %Z') - elapsed_total_seconds = response.elapsed.total_seconds() - logger.debug(f'{url} fetched in {elapsed_total_seconds}') - service_is_healthy = (500 <= response.status_code <= 511) - return not service_is_healthy, elapsed_total_seconds, expire_date diff --git a/backends/__init__.py b/backends/__init__.py new file mode 100644 index 0000000..ca8fdcf --- /dev/null +++ b/backends/__init__.py @@ -0,0 +1,5 @@ +from .request_backend import RequestBackend + +__all__ = [ + 'RequestBackend', +] diff --git a/backends/base_backend.py b/backends/base_backend.py new file mode 100644 index 0000000..6cc512a --- /dev/null +++ b/backends/base_backend.py @@ -0,0 +1,16 @@ +from abc import ABC, abstractmethod +from datetime import datetime +from typing import Tuple, Optional + + +class BaseBackend(ABC): + def __init__(self, service) -> None: + self.service = service + + @abstractmethod + async def check( + self, + *args, + **kwargs, + ) -> Tuple[bool, Optional[float], Optional[datetime], Optional[int]]: # pragma: no cover + pass diff --git a/backends/request_backend.py b/backends/request_backend.py new file mode 100644 index 0000000..38e0f18 --- /dev/null +++ b/backends/request_backend.py @@ -0,0 +1,33 @@ +import logging +from datetime import datetime +from typing import Tuple, Optional + +import httpx +import ssl + +from .base_backend import BaseBackend + +logger = logging.getLogger(__name__) + + +class RequestBackend(BaseBackend): + async def check(self, session) -> Tuple[bool, Optional[float], Optional[datetime], Optional[int]]: + headers = { + 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 ' + '(KHTML, like Gecko) Chrome/91.0.4472.114 Safari/537.36' + } + try: + logger.debug(f"Fetching {self.service.url}") + response = await session.request(method='GET', url=self.service.url, headers=headers) + except (httpx.HTTPError, ssl.SSLCertVerificationError,) as exc: + logger.warning(f'"{self.service.url}" request failed {exc}') + return False, None, None, None + else: + raw_stream = response.extensions['network_stream'] + ssl_object = raw_stream.get_extra_info('ssl_object') + cert = ssl_object.getpeercert() + expire_date = datetime.strptime(cert['notAfter'], '%b %d %H:%M:%S %Y %Z') + elapsed_total_seconds = response.elapsed.total_seconds() + logger.debug(f'{self.service.url} fetched in {elapsed_total_seconds}') + service_is_healthy = (400 <= response.status_code <= 511) + return not service_is_healthy, elapsed_total_seconds, expire_date, response.status_code diff --git a/commands/add_service.py b/commands/add_service.py index 9eedf6a..0e70ce4 100644 --- a/commands/add_service.py +++ b/commands/add_service.py @@ -2,8 +2,7 @@ from telegram import Update from telegram.ext import ContextTypes -from command_handlers import add_service_command_handler -from models import HEALTHCHECK_BACKENDS +from commands.handlers import add_service_command_handler logger = logging.getLogger(__name__) @@ -11,20 +10,12 @@ async def add_service(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: # Validate arguments - if len(context.args) != 4: - await update.message.reply_text('Please, use /add ') + if len(context.args) != 2: + await update.message.reply_text('Please, use /add ') return - service_type, name, domain, port = context.args - if service_type.lower() not in HEALTHCHECK_BACKENDS.keys(): - await update.message.reply_text(f' must be {", ".join(HEALTHCHECK_BACKENDS.keys())}') - return - try: - port = int(port) - except ValueError: - await update.message.reply_text(' must be a number') - return + name, url = context.args - service = add_service_command_handler(update.effective_chat.id, service_type, name, domain, port) + service = add_service_command_handler(update.effective_chat.id, name, url) await update.message.reply_text(f'ok! I\'ve added {service}') diff --git a/commands/check_all_services.py b/commands/check_all_services.py index 82146d8..0690b08 100644 --- a/commands/check_all_services.py +++ b/commands/check_all_services.py @@ -4,7 +4,7 @@ from telegram.constants import ParseMode from telegram.ext import ContextTypes -from command_handlers import chat_services_checker_command_handler +from commands.handlers import chat_services_checker_command_handler from models import Service logger = logging.getLogger(__name__) @@ -17,7 +17,10 @@ async def check_all_services(context: ContextTypes.DEFAULT_TYPE): fetched_services: Dict[str, List[Service]] = chat_fetched_services[chat_id] unhealthy_service: Service for unhealthy_service in fetched_services['unhealthy']: - text = f'{unhealthy_service.name} is down 🤕!' + text = ( + f'{unhealthy_service.name} is down 🤕! ' + f'\n `HTTP_STATUS_CODE={unhealthy_service.last_http_response_status_code}`' + ) await context.bot.send_message(chat_id=chat_id, text=text) healthy_service: Service for healthy_service in fetched_services['healthy']: @@ -27,5 +30,8 @@ async def check_all_services(context: ContextTypes.DEFAULT_TYPE): except (KeyError, TypeError) as e: logger.debug(f'Exception occurred: {e}') suffix = '' - text = f'{healthy_service.name} is fixed now{suffix} 😅!' + text = ( + f'{healthy_service.name} is fixed now{suffix} 😅!' + f'\n `HTTP_STATUS_CODE={healthy_service.last_http_response_status_code}`' + ) await context.bot.send_message(chat_id=chat_id, text=text, parse_mode=ParseMode.MARKDOWN) diff --git a/command_handlers.py b/commands/handlers.py similarity index 85% rename from command_handlers.py rename to commands/handlers.py index 4403945..17de38d 100644 --- a/command_handlers.py +++ b/commands/handlers.py @@ -5,7 +5,8 @@ import httpx -from models import Service, ServiceManager, ServiceStatus +from models import Service, ServiceStatus +from repositories import ServiceRepository from persistence import LocalJsonRepository logger = logging.getLogger(__name__) @@ -17,7 +18,7 @@ async def chat_service_checker_command_handler(chat_id: str) -> dict[str, dict[str, Any]]: persistence = LocalJsonRepository.create(chat_id) - service_manager = ServiceManager(persistence) + service_manager = ServiceRepository(persistence) active_services = service_manager.fetch_active() unhealthy_services = [] healthy_services = [] @@ -29,8 +30,9 @@ async def chat_service_checker_command_handler(chat_id: str) -> dict[str, dict[s backend_checks.append(service.healthcheck_backend.check(session)) responses = await asyncio.gather(*backend_checks) services = [] - for service, (service_is_healthy, time_to_first_byte, expire_date) in zip(active_services, responses): + for service, (service_is_healthy, time_to_first_byte, expire_date, http_status) in zip(active_services, responses): initial_service_status = service.status + service.last_http_response_status_code = http_status if service_is_healthy is False: service.status = ServiceStatus.UNHEALTHY if initial_service_status != ServiceStatus.UNHEALTHY: @@ -50,6 +52,7 @@ async def chat_service_checker_command_handler(chat_id: str) -> dict[str, dict[s service.last_time_healthy = now_utc service.time_to_first_byte = time_to_first_byte service.expire_date = expire_date + service.last_http_response_status_code = http_status services.append(service.to_dict()) service_manager.update(services) @@ -75,25 +78,26 @@ async def chat_services_checker_command_handler() -> dict[str, dict]: return all_chats_fetched_services -def add_service_command_handler(chat_id, service_type, name, domain, port) -> Service: +def add_service_command_handler(chat_id, name, url) -> Service: persistence = LocalJsonRepository.create(chat_id) - return ServiceManager(persistence).add(service_type, name, domain, port) + return ServiceRepository(persistence).add(name, url) def remove_services_command_handler(name, chat_id: str) -> None: persistence = LocalJsonRepository.create(chat_id) - ServiceManager(persistence).remove(name) + ServiceRepository(persistence).remove(name) def list_services_command_handler(chat_id: str) -> str: persistence = LocalJsonRepository.create(chat_id) - all_services = ServiceManager(persistence).fetch_all() + all_services = ServiceRepository(persistence).fetch_all() if not all_services: return 'There is nothing to see here' result = '' for service in all_services: result += '\n\n' result += f'`{service.name}` is {service.status.value.upper()}' + result += f'\nHTTP status code: `{service.last_http_response_status_code}`' if service.status == ServiceStatus.HEALTHY and service.time_to_first_byte is not None: result += f'\nttfb: `{service.time_to_first_byte}`' elif service.status == ServiceStatus.UNHEALTHY and service.last_time_healthy is not None: diff --git a/commands/list_services.py b/commands/list_services.py index dfc03e6..8c95832 100644 --- a/commands/list_services.py +++ b/commands/list_services.py @@ -2,7 +2,7 @@ from telegram.constants import ParseMode from telegram.ext import ContextTypes -from command_handlers import list_services_command_handler +from commands.handlers import list_services_command_handler async def list_services(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: diff --git a/commands/remove_service.py b/commands/remove_service.py index 3401b62..21e2907 100644 --- a/commands/remove_service.py +++ b/commands/remove_service.py @@ -1,7 +1,7 @@ from telegram import Update from telegram.ext import ContextTypes -from command_handlers import remove_services_command_handler +from commands.handlers import remove_services_command_handler async def remove_service(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: diff --git a/filters/__init__.py b/filters/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/filter_allowed_chats.py b/filters/allowed_chats.py similarity index 92% rename from filter_allowed_chats.py rename to filters/allowed_chats.py index 8d2b219..da21c82 100644 --- a/filter_allowed_chats.py +++ b/filters/allowed_chats.py @@ -7,7 +7,7 @@ logger = logging.getLogger(__name__) -class FilterAllowedChats(MessageFilter): +class AllowedChatsMessageFilter(MessageFilter): def __init__(self, allowed_chat_ids: List[str]): super().__init__() diff --git a/main.py b/main.py index ce05e01..443c941 100644 --- a/main.py +++ b/main.py @@ -5,7 +5,7 @@ from telegram.ext import ApplicationBuilder, CommandHandler from commands import check_all_services, list_services, remove_service, add_service, error -from filter_allowed_chats import FilterAllowedChats +from filters.allowed_chats import AllowedChatsMessageFilter abspath = os.path.abspath(__file__) directory_name = os.path.dirname(abspath) @@ -21,13 +21,17 @@ logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=LOG_LEVEL) +logging.getLogger('telegram').setLevel(logging.WARNING) +logging.getLogger('httpx').setLevel(logging.WARNING) +logging.getLogger('apscheduler').setLevel(logging.WARNING) + logger = logging.getLogger(__name__) def main() -> None: app = ApplicationBuilder().token(BOT_TOKEN).build() - filter_allowed_chats = FilterAllowedChats(ALLOWED_CHAT_IDS) + filter_allowed_chats = AllowedChatsMessageFilter(ALLOWED_CHAT_IDS) job_queue = app.job_queue job_queue.run_repeating(check_all_services, POLLING_INTERVAL, first=1) diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..ee25867 --- /dev/null +++ b/models/__init__.py @@ -0,0 +1,6 @@ +from .service import Service, ServiceStatus + +__all__ = [ + 'Service', + 'ServiceStatus', +] diff --git a/models/service.py b/models/service.py new file mode 100644 index 0000000..1c1fc3e --- /dev/null +++ b/models/service.py @@ -0,0 +1,48 @@ +import enum +from dataclasses import dataclass, field, asdict +from datetime import datetime +from typing import Optional, Dict + +from backends import RequestBackend + + +class ServiceStatus(enum.Enum): + UNKNOWN = 'unknown' + HEALTHY = 'healthy' + UNHEALTHY = 'unhealthy' + + +def service_asdict_factory(data): + def convert_value(obj): + if isinstance(obj, ServiceStatus): + return obj.value + elif isinstance(obj, datetime): + return obj.strftime('%Y-%m-%dT%H:%M:%S.%f') + return obj + + return dict((k, convert_value(v)) for k, v in data) + + +@dataclass +class Service: + name: str = field() + url: str = field() + enabled: bool = field(default=True) + last_time_healthy: Optional[datetime] = field(default=None) + last_http_response_status_code: Optional[int] = field(default=None) + time_to_first_byte: float = field(default=0.0) + status: ServiceStatus = field(init=True, default=ServiceStatus.UNKNOWN) + expire_date: Optional[datetime] = field(default=None) + + @property + def healthcheck_backend(self) -> RequestBackend: + return RequestBackend(self) + + def __repr__(self) -> str: # pragma: no cover + return f'{self.name} <{self.url}>' + + def __str__(self) -> str: # pragma: no cover + return f'{self.name} <{self.url}>' + + def to_dict(self) -> Dict: + return asdict(self, dict_factory=service_asdict_factory) diff --git a/persistence/__init__.py b/persistence/__init__.py new file mode 100644 index 0000000..f21cb75 --- /dev/null +++ b/persistence/__init__.py @@ -0,0 +1,7 @@ +from .base_persistence import BaseRepository +from .local_json_repository import LocalJsonRepository + +__all__ = [ + 'BaseRepository', + 'LocalJsonRepository', +] diff --git a/persistence/base_persistence.py b/persistence/base_persistence.py new file mode 100644 index 0000000..d7c4c82 --- /dev/null +++ b/persistence/base_persistence.py @@ -0,0 +1,30 @@ +from abc import ABC, abstractmethod +from typing import Dict, List + + +class BaseRepository(ABC): + + @classmethod + @abstractmethod + def create(cls, chat_id: str): # pragma: no cover + pass + + @abstractmethod + def fetch_all(self): # pragma: no cover + pass + + @abstractmethod + def add(self, data_to_append: Dict): # pragma: no cover + pass + + @abstractmethod + def remove(self, name: str): # pragma: no cover + pass + + @abstractmethod + def update(self, service_to_update: Dict): # pragma: no cover + pass + + @abstractmethod + def bulk_update(self, services_to_update: List[Dict]): # pragma: no cover + pass diff --git a/persistence.py b/persistence/local_json_repository.py similarity index 76% rename from persistence.py rename to persistence/local_json_repository.py index 2f26229..0011b8e 100644 --- a/persistence.py +++ b/persistence/local_json_repository.py @@ -1,40 +1,13 @@ import json import logging -from abc import ABC, abstractmethod from json import JSONDecodeError from os import listdir from os.path import isfile, join, splitext from typing import Dict, List -logger = logging.getLogger(__name__) - - -class BaseRepository(ABC): - - @classmethod - @abstractmethod - def create(cls, chat_id: str): # pragma: no cover - pass +from .base_persistence import BaseRepository - @abstractmethod - def fetch_all(self): # pragma: no cover - pass - - @abstractmethod - def add(self, data_to_append: Dict): # pragma: no cover - pass - - @abstractmethod - def remove(self, name: str): # pragma: no cover - pass - - @abstractmethod - def update(self, service_to_update: Dict): # pragma: no cover - pass - - @abstractmethod - def bulk_update(self, services_to_update: List[Dict]): # pragma: no cover - pass +logger = logging.getLogger(__name__) class LocalJsonRepository(BaseRepository): diff --git a/repositories/__init__.py b/repositories/__init__.py new file mode 100644 index 0000000..c25e7d6 --- /dev/null +++ b/repositories/__init__.py @@ -0,0 +1,5 @@ +from .service_repository import ServiceRepository + +__all__ = [ + 'ServiceRepository', +] diff --git a/models.py b/repositories/service_repository.py similarity index 55% rename from models.py rename to repositories/service_repository.py index fc4a6d7..6b1b88c 100644 --- a/models.py +++ b/repositories/service_repository.py @@ -1,64 +1,14 @@ -import enum -from dataclasses import asdict, dataclass, field import logging from datetime import datetime, timezone -from typing import Dict, List, Optional +from typing import Dict, List -from backends import BaseBackend, RequestBackend, SocketBackend +from models.service import ServiceStatus, Service from persistence import BaseRepository logger = logging.getLogger(__name__) -HEALTHCHECK_BACKENDS = { - 'socket': SocketBackend, - 'request': RequestBackend, -} - -class ServiceStatus(enum.Enum): - UNKNOWN = 'unknown' - HEALTHY = 'healthy' - UNHEALTHY = 'unhealthy' - - -def service_asdict_factory(data): - def convert_value(obj): - if isinstance(obj, ServiceStatus): - return obj.value - elif isinstance(obj, datetime): - return obj.strftime('%Y-%m-%dT%H:%M:%S.%f') - return obj - - return dict((k, convert_value(v)) for k, v in data) - - -@dataclass -class Service: - service_type: str = field() - name: str = field() - domain: str = field() - port: int = field() - enabled: bool = field(default=True) - last_time_healthy: Optional[datetime] = field(default=None) - time_to_first_byte: float = field(default=0.0) - status: ServiceStatus = field(init=True, default=ServiceStatus.UNKNOWN) - expire_date: Optional[datetime] = field(default=None) - - @property - def healthcheck_backend(self) -> BaseBackend: - return HEALTHCHECK_BACKENDS[self.service_type](self) - - def __repr__(self) -> str: # pragma: no cover - return f'{self.name} <{self.domain}:{self.port}>' - - def __str__(self) -> str: # pragma: no cover - return f'{self.name} <{self.domain}>' - - def to_dict(self) -> Dict: - return asdict(self, dict_factory=service_asdict_factory) - - -class ServiceManager: +class ServiceRepository: def __init__(self, persistence_backend: BaseRepository) -> None: self.persistence_backend = persistence_backend @@ -104,8 +54,8 @@ def fetch_all(self) -> List[Service]: def fetch_active(self) -> List[Service]: return [service for service in self.fetch_all() if service.enabled is True] - def add(self, service_type: str, name: str, domain: str, port: int, enabled: bool = True) -> Service: - service = Service(service_type.lower(), name, domain, int(port), enabled) + def add(self, name: str, url: str) -> Service: + service = Service(name, url, ) self.persistence_backend.add(service.to_dict()) return service diff --git a/tests/test_backends.py b/tests/test_backends.py index 4d4f551..9777e1c 100644 --- a/tests/test_backends.py +++ b/tests/test_backends.py @@ -1,43 +1,9 @@ import unittest -from socket import error, timeout from unittest.mock import MagicMock, patch import httpx -from backends import RequestBackend, SocketBackend - - -@patch('backends.socket') -class TestSocketBackend(unittest.TestCase): - - def setUp(self) -> None: - super().setUp() - service = MagicMock(domain='fake', port=456) - self.backend = SocketBackend(service) - - def test_success(self, mock_socket): - mock_socket.return_value.connect_ex.return_value = 0 - - is_healthy, time_to_first_byte, expire_date = self.backend.check() - self.assertTrue(is_healthy) - - def test_error(self, mock_socket): - mock_socket.return_value.connect_ex.return_value = 1 - - is_healthy, time_to_first_byte, expire_date = self.backend.check() - self.assertFalse(is_healthy) - - def test_error_exception(self, mock_socket): - mock_socket.return_value.connect_ex.side_effect = error - - is_healthy, time_to_first_byte, expire_date = self.backend.check() - self.assertFalse(is_healthy) - - def test_timeout_exception(self, mock_socket): - mock_socket.return_value.connect_ex.side_effect = timeout - - is_healthy, time_to_first_byte, expire_date = self.backend.check() - self.assertFalse(is_healthy) +from backends import RequestBackend class TestRequestBackend(unittest.TestCase): diff --git a/tests/test_command_handlers.py b/tests/test_command_handlers.py index 0c4e504..a38702b 100644 --- a/tests/test_command_handlers.py +++ b/tests/test_command_handlers.py @@ -1,9 +1,9 @@ import unittest from unittest.mock import MagicMock, PropertyMock, patch -from command_handlers import (add_service_command_handler, chat_service_checker_command_handler, - chat_services_checker_command_handler, - list_services_command_handler, remove_services_command_handler) +from commands.handlers import (add_service_command_handler, chat_service_checker_command_handler, + chat_services_checker_command_handler, + list_services_command_handler, remove_services_command_handler) from models import Service, ServiceStatus @@ -16,8 +16,8 @@ def mock_chat_handler_side_effect(*args, **kwargs): class TestCommandHandlers(unittest.TestCase): - @patch('command_handlers.LocalJsonRepository.create') - @patch('command_handlers.ServiceManager.fetch_active') + @patch('commands.handlers.LocalJsonRepository.create') + @patch('commands.handlers.ServiceRepository.fetch_active') async def test_chat_service_checker_command_handler( self, mock_service_manager, @@ -29,24 +29,18 @@ async def test_chat_service_checker_command_handler( mock_healthcheck_backend.return_value = MagicMock(check=mock_ht_service) mock_service_manager.return_value = [ Service( - service_type='request', name='test', - domain='test.com', - port=443, + url='test.com', status=ServiceStatus.HEALTHY, ), Service( - service_type='request', name='test2', - domain='test2.com', - port=443, + url='test2.com', status=ServiceStatus.UNHEALTHY, ), Service( - service_type='socket', name='test3', - domain='test3', - port=4442, + url='test3', status=ServiceStatus.UNKNOWN, ), ] @@ -56,8 +50,8 @@ async def test_chat_service_checker_command_handler( self.assertIsInstance(chat_services, dict) self.assertTrue(len(chat_services), 1) - @patch('command_handlers.LocalJsonRepository.create') - @patch('command_handlers.ServiceManager.fetch_active') + @patch('commands.handlers.LocalJsonRepository.create') + @patch('commands.handlers.ServiceRepository.fetch_active') async def test_chat_service_checker_command_handler_empty( self, mock_service_manager, @@ -74,8 +68,8 @@ async def test_chat_service_checker_command_handler_empty( self.assertIsInstance(chat_services, dict) self.assertEqual(len(chat_services), 0) - @patch('command_handlers.LocalJsonRepository.create') - @patch('command_handlers.ServiceManager.fetch_active') + @patch('commands.handlers.LocalJsonRepository.create') + @patch('commands.handlers.ServiceRepository.fetch_active') async def test_chat_service_checker_command_handler_unhealthy( self, mock_service_manager, @@ -87,10 +81,8 @@ async def test_chat_service_checker_command_handler_unhealthy( mock_healthcheck_backend.return_value = MagicMock(check=mock_ht_service) mock_service_manager.return_value = [ Service( - service_type='request', name='test', - domain='test.com', - port=443, + url='test.com', status=ServiceStatus.HEALTHY, ), ] @@ -100,8 +92,8 @@ async def test_chat_service_checker_command_handler_unhealthy( self.assertIsInstance(chat_services, dict) self.assertTrue(len(chat_services), 1) - @patch('command_handlers.LocalJsonRepository.get_all_chat_ids') - @patch('command_handlers.chat_service_checker_command_handler') + @patch('commands.handlers.LocalJsonRepository.get_all_chat_ids') + @patch('commands.handlers.chat_service_checker_command_handler') async def test_chat_services_checker_command_handler(self, mock_chat_handler, mock_chat_ids): mock_chat_handler.side_effect = mock_chat_handler_side_effect mock_chat_ids.return_value = ['1234', '5678'] @@ -111,20 +103,20 @@ async def test_chat_services_checker_command_handler(self, mock_chat_handler, mo self.assertIsInstance(chat_failing_services, dict) self.assertTrue(len(chat_failing_services), 2) - @patch('command_handlers.LocalJsonRepository.create') - @patch('command_handlers.ServiceManager.add') + @patch('commands.handlers.LocalJsonRepository.create') + @patch('commands.handlers.ServiceRepository.add') def test_add_service_command_handler(self, mock_service_manager_add, mock_repository_create): - mock_service_manager_add.return_value = Service('request', 'test', 'test.com', 443) + mock_service_manager_add.return_value = Service('test', 'test.com') mock_repository_create.return_value = MagicMock() - service = add_service_command_handler('1234', 'request', 'test', 'test.com', 443) + service = add_service_command_handler('1234', 'test', 'test.com') self.assertIsInstance(service, Service) mock_service_manager_add.assert_called_once() mock_repository_create.assert_called_once() - @patch('command_handlers.ServiceManager.remove') - @patch('command_handlers.LocalJsonRepository.create') + @patch('commands.handlers.ServiceRepository.remove') + @patch('commands.handlers.LocalJsonRepository.create') def test_remove_services_command_handler(self, mock_service_manager_remove, mock_repository_create): mock_service_manager_remove.return_value = True mock_repository_create.return_value = MagicMock() @@ -134,13 +126,13 @@ def test_remove_services_command_handler(self, mock_service_manager_remove, mock mock_service_manager_remove.assert_called_once() mock_repository_create.assert_called_once() - @patch('command_handlers.LocalJsonRepository.create') - @patch('command_handlers.ServiceManager.fetch_all') + @patch('commands.handlers.LocalJsonRepository.create') + @patch('commands.handlers.ServiceRepository.fetch_all') def test_list_services_command_handler(self, mock_fetch_all, mock_repository_create): mock_fetch_all.return_value = [ - Service(service_type='request', name='test', domain='test.com', port=443, status=ServiceStatus.HEALTHY), - Service(service_type='request', name='test2', domain='test2.com', port=443, status=ServiceStatus.UNHEALTHY), - Service(service_type='socket', name='test3', domain='test3', port=4442, status=ServiceStatus.UNKNOWN), + Service(name='test', url='test.com', status=ServiceStatus.HEALTHY), + Service(name='test2', url='test2.com', status=ServiceStatus.UNHEALTHY), + Service(name='test3', url='test3', status=ServiceStatus.UNKNOWN), ] mock_repository_create.return_value = MagicMock() services = list_services_command_handler('1234') diff --git a/tests/test_services.py b/tests/test_services.py index d3738c8..799db18 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -1,7 +1,8 @@ import unittest from unittest.mock import MagicMock -from models import ServiceManager, ServiceStatus, Service +from models import ServiceStatus, Service +from repositories import ServiceRepository class TestServices(unittest.TestCase): @@ -9,7 +10,7 @@ class TestServices(unittest.TestCase): def setUp(self) -> None: super().setUp() mock_persistence_backend = MagicMock() - self.service_manager = ServiceManager(mock_persistence_backend) + self.service_manager = ServiceRepository(mock_persistence_backend) def test_mark_as_healthy(self): mock_service = MagicMock() @@ -23,30 +24,24 @@ def test_mark_as_unhealthy(self): def test_fetch_all(self): self.service_manager.persistence_backend.fetch_all.return_value = [ - {"service_type": "request", "name": "test1", "domain": "test1.com", "port": 443, "enabled": True, - "status": "unknown"}, - {"service_type": "request", "name": "test2", "domain": "test2.com", "port": 443, "enabled": True, - "status": "healthy"}, - {"service_type": "request", "name": "test3", "domain": "test3.com", "port": 443, "enabled": True, - "status": "unhealthy"}, + {"name": "test1", "url": "test1.com", "status": "unknown"}, + {"name": "test2", "url": "test2.com", "status": "healthy"}, + {"name": "test3", "url": "test3.com", "status": "unhealthy"}, ] services = self.service_manager.fetch_all() self.assertTrue(all([isinstance(service, Service) for service in services])) def test_fetch_active(self): self.service_manager.persistence_backend.fetch_all.return_value = [ - {"service_type": "request", "name": "test1", "domain": "test1.com", "port": 443, "enabled": True, - "status": "unknown"}, - {"service_type": "request", "name": "test2", "domain": "test2.com", "port": 443, "enabled": True, - "status": "healthy"}, - {"service_type": "request", "name": "test3", "domain": "test3.com", "port": 443, "enabled": False, - "status": "unhealthy"}, + {"name": "test1", "url": "test1.com", "status": "unknown"}, + {"name": "test2", "url": "test2.com", "status": "healthy"}, + {"name": "test3", "url": "test3.com", "status": "unhealthy"}, ] services = self.service_manager.fetch_active() self.assertTrue(all([isinstance(service, Service) and service.enabled is True for service in services])) def test_add(self): - self.service_manager.add('request', 'test', 'test.com', 443) + self.service_manager.add('test', 'test.com') def test_remove(self): self.service_manager.remove('test1')