Skip to content

Commit

Permalink
type hints -> fix
Browse files Browse the repository at this point in the history
  • Loading branch information
dam2452 committed Jul 6, 2024
1 parent 52a5119 commit 2907bfc
Show file tree
Hide file tree
Showing 12 changed files with 55 additions and 39 deletions.
3 changes: 2 additions & 1 deletion Preprocessing/BOT_EXTRACTION_AUDIO.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
import json
import os
import subprocess
from typing import Optional


def get_best_audio_stream(video_file) -> int or None:
def get_best_audio_stream(video_file) -> Optional[int]:
"""Zwraca indeks najlepszej ścieżki audio na podstawie bitrate."""
try:
cmd = f'ffprobe -v quiet -print_format json -show_streams -select_streams a "{video_file}"'
Expand Down
2 changes: 1 addition & 1 deletion bot/handlers/adjust_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ async def adjust_video_clip(message: types.Message, bot: Bot) -> None:
await DatabaseManager.log_system_message("ERROR", f"Error in adjust_video_clip for user '{message.from_user.username}': {e}")


def register_adjust_handler(dispatcher: Dispatcher):
def register_adjust_handler(dispatcher: Dispatcher) -> None:
dispatcher.include_router(router)


Expand Down
16 changes: 9 additions & 7 deletions bot/handlers/admin_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
)
from aiogram.filters import Command
from tabulate import tabulate
from typing import Optional

from bot.middlewares.auth_middleware import AuthorizationMiddleware
from bot.middlewares.error_middleware import ErrorHandlerMiddleware
Expand All @@ -15,6 +16,7 @@

logger = logging.getLogger(__name__)
router = Router()
dis = Dispatcher()


# Definicja UserManager dla łatwiejszego dostępu do funkcji zarządzania użytkownikami
Expand All @@ -32,31 +34,31 @@ async def update_user(username, is_admin=None, is_moderator=None, full_name=None
await DatabaseManager.update_user(username, is_admin, is_moderator, full_name, email, phone, subscription_end)

@staticmethod
async def get_all_users() -> list or None:
async def get_all_users() -> Optional[list]: # TO DO: Change return type
return await DatabaseManager.get_all_users()

@staticmethod
async def get_admin_users() -> list or None:
async def get_admin_users() -> Optional[list]: # TO DO: Change return type
return await DatabaseManager.get_admin_users()

@staticmethod
async def get_moderator_users() -> list or None:
async def get_moderator_users() -> Optional[list]: # TO DO: Change return type
return await DatabaseManager.get_moderator_users()

@staticmethod
async def add_subscription(username, days) -> str or None:
async def add_subscription(username, days) -> Optional[str]: # TO DO: Change return type
return await DatabaseManager.add_subscription(username, days)

@staticmethod
async def remove_subscription(username) -> None:
await DatabaseManager.remove_subscription(username)

@staticmethod
async def is_user_admin(username) -> bool or None:
async def is_user_admin(username) -> Optional[bool]: # TO DO: Change return type
return await DatabaseManager.is_user_admin(username)

@staticmethod
async def is_user_moderator(username) -> bool or None:
async def is_user_moderator(username) -> Optional[bool]: # TO DO: Change return type
return await DatabaseManager.is_user_moderator(username)


Expand Down Expand Up @@ -350,7 +352,7 @@ async def handle_transcription_request(message: types.Message) -> None:
logger.info(f"Searching transcription for quote: '{quote}'")
await DatabaseManager.log_user_activity(message.from_user.username, f"/transkrypcja {quote}")

search_transcriptions = SearchTranscriptions(router)
search_transcriptions = SearchTranscriptions(dis)
context_size = 15
result = await search_transcriptions.find_segment_with_context(quote, context_size)

Expand Down
5 changes: 3 additions & 2 deletions bot/handlers/episode_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
types,
)
from aiogram.filters import Command
from typing import Optional

from bot.middlewares.auth_middleware import AuthorizationMiddleware
from bot.middlewares.error_middleware import ErrorHandlerMiddleware
Expand All @@ -17,14 +18,14 @@
router = Router()
dis = Dispatcher()

def adjust_episode_number(absolute_episode) -> tuple[int, int] or None:
def adjust_episode_number(absolute_episode) -> Optional[tuple[int, int]]:
""" Adjust the absolute episode number to season and episode format """
season = (absolute_episode - 1) // 13 + 1
episode = (absolute_episode - 1) % 13 + 1
return season, episode


def split_message(message, max_length=4096) -> list[str] or str or None:
def split_message(message, max_length=4096) -> Optional[list[str]]:
""" Splits a message into chunks to fit within the Telegram message length limit """
parts = []
while len(message) > max_length:
Expand Down
5 changes: 3 additions & 2 deletions bot/handlers/manual_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
types,
)
from aiogram.filters import Command
from typing import Optional

from bot.middlewares.auth_middleware import AuthorizationMiddleware
from bot.middlewares.error_middleware import ErrorHandlerMiddleware
Expand All @@ -22,7 +23,7 @@
last_manual_clip = {} # Dictionary to store the last manual clip per chat ID


def minutes_str_to_seconds(time_str) -> float or None:
def minutes_str_to_seconds(time_str) -> Optional[float]:
""" Convert time string in the format MM:SS.ms to seconds """
try:
minutes, seconds = time_str.split(':')
Expand All @@ -33,7 +34,7 @@ def minutes_str_to_seconds(time_str) -> float or None:
return None


def adjust_episode_number(absolute_episode) -> tuple[int, int] or None:
def adjust_episode_number(absolute_episode) -> Optional[tuple[int, int]]:
""" Adjust the absolute episode number to season and episode format """
season = (absolute_episode - 1) // 13 + 1
episode = (absolute_episode - 1) % 13 + 1
Expand Down
3 changes: 2 additions & 1 deletion bot/handlers/subscription_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
types,
)
from aiogram.filters import Command
from typing import Optional

from bot.middlewares.auth_middleware import AuthorizationMiddleware
from bot.middlewares.error_middleware import ErrorHandlerMiddleware
Expand All @@ -19,7 +20,7 @@

class UserManager:
@staticmethod
async def get_subscription_status(username: str) -> tuple[None, None] or tuple[date, int] or None:
async def get_subscription_status(username: str) -> Optional[tuple[date, int]]: # TO DO: Change return type
subscription_end = await DatabaseManager.get_user_subscription(username)
if subscription_end is None:
return None
Expand Down
3 changes: 2 additions & 1 deletion bot/middlewares/auth_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@

from aiogram import BaseMiddleware
from aiogram.types import Message
from typing import Awaitable

from bot.utils.database import DatabaseManager

logger = logging.getLogger(__name__)


class AuthorizationMiddleware(BaseMiddleware):
async def __call__(self, handler, event, data) -> None or bool:
async def __call__(self, handler, event, data) -> Awaitable: # TO DO: Change return type
if not isinstance(event, Message):
return await handler(event, data)

Expand Down
3 changes: 2 additions & 1 deletion bot/middlewares/error_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@

from aiogram import BaseMiddleware
from aiogram.types import Message
from typing import Awaitable

from bot.utils.database import DatabaseManager

logger = logging.getLogger(__name__)


class ErrorHandlerMiddleware(BaseMiddleware):
async def __call__(self, handler, event, data) -> None or bool:
async def __call__(self, handler, event, data) -> Awaitable: # TO DO: Change return type
try:
return await handler(event, data)
except Exception as e:
Expand Down
35 changes: 20 additions & 15 deletions bot/utils/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
)

import asyncpg
from typing import (
# Coroutine,
# Any,
Optional,
)

from bot.settings import (
POSTGRES_DB,
Expand All @@ -16,7 +21,7 @@

class DatabaseManager:
@staticmethod
async def get_db_connection() -> asyncpg.Connection or None:
async def get_db_connection() -> Optional[asyncpg.Connection]: # TO DO: Change return type
return await asyncpg.connect(
host=POSTGRES_HOST,
port=POSTGRES_PORT,
Expand Down Expand Up @@ -168,7 +173,7 @@ async def remove_user(username) -> None:
await conn.close()

@staticmethod
async def get_all_users() -> list[asyncpg.Record] or None:
async def get_all_users() -> Optional[list[asyncpg.Record]]: # TO DO: Change return type
conn = await DatabaseManager.get_db_connection()
result = await conn.fetch(
'SELECT username, is_admin, is_moderator, full_name, email, phone, subscription_end FROM users',
Expand All @@ -177,21 +182,21 @@ async def get_all_users() -> list[asyncpg.Record] or None:
return result

@staticmethod
async def get_admin_users() -> list[asyncpg.Record] or None:
async def get_admin_users() -> Optional[list[asyncpg.Record]]: # TO DO: Change return type
conn = await DatabaseManager.get_db_connection()
result = await conn.fetch('SELECT username, full_name, email, phone FROM users WHERE is_admin = TRUE')
await conn.close()
return result

@staticmethod
async def get_moderator_users() -> list[asyncpg.Record] or None:
async def get_moderator_users() -> Optional[list[asyncpg.Record]]: # TO DO: Change return type
conn = await DatabaseManager.get_db_connection()
result = await conn.fetch('SELECT username, full_name, email, phone FROM users WHERE is_moderator = TRUE')
await conn.close()
return result

@staticmethod
async def is_user_authorized(username) -> bool or None:
async def is_user_authorized(username) -> bool:
conn = await DatabaseManager.get_db_connection()
result = await conn.fetchrow(
'SELECT is_admin, is_moderator, subscription_end FROM users WHERE username = $1',
Expand All @@ -207,14 +212,14 @@ async def is_user_authorized(username) -> bool or None:
return False

@staticmethod
async def is_user_admin(username) -> bool or None:
async def is_user_admin(username) -> Optional[bool]: # TO DO: Change return type
conn = await DatabaseManager.get_db_connection()
result = await conn.fetchval('SELECT is_admin FROM users WHERE username = $1', username)
await conn.close()
return result

@staticmethod
async def is_user_moderator(username) -> bool or None:
async def is_user_moderator(username) -> Optional[bool]: # TO DO: Change return type
conn = await DatabaseManager.get_db_connection()
result = await conn.fetchval('SELECT is_moderator FROM users WHERE username = $1', username)
await conn.close()
Expand All @@ -233,7 +238,7 @@ async def set_default_admin(admin_id) -> None:
await conn.close()

@staticmethod
async def get_saved_clips(username) -> list[asyncpg.Record] or None:
async def get_saved_clips(username) -> Optional[list[asyncpg.Record]]: # TO DO: Change return type
conn = await DatabaseManager.get_db_connection()
result = await conn.fetch(
'SELECT clip_name, start_time, end_time, season, episode_number, is_compilation FROM clips WHERE username = $1',
Expand All @@ -259,7 +264,7 @@ async def save_clip(
await conn.close()

@staticmethod
async def get_clip_by_name(username, clip_name) -> asyncpg.Record or None:
async def get_clip_by_name(username, clip_name) -> Optional[tuple[bytes, int, int]]: # TO DO: Change return type
conn = await DatabaseManager.get_db_connection()
result = await conn.fetchrow(
'''
Expand All @@ -272,7 +277,7 @@ async def get_clip_by_name(username, clip_name) -> asyncpg.Record or None:
return result

@staticmethod
async def get_clip_by_index(username, index) -> tuple[str, int, int, int, int, bool] or None:
async def get_clip_by_index(username, index) -> Optional[tuple[str, int, int, int, int, bool]]: # TO DO: Change return type
conn = await DatabaseManager.get_db_connection()
clip = await conn.fetchrow(
'''
Expand All @@ -291,7 +296,7 @@ async def get_clip_by_index(username, index) -> tuple[str, int, int, int, int, b
return None

@staticmethod
async def get_video_data_by_name(username, clip_name) -> bytes or None:
async def get_video_data_by_name(username, clip_name) -> Optional[bytes]: # TO DO: Change return type
conn = await DatabaseManager.get_db_connection()
result = await conn.fetchval(
'''
Expand All @@ -305,7 +310,7 @@ async def get_video_data_by_name(username, clip_name) -> bytes or None:
return result

@staticmethod
async def add_subscription(username, days) -> date or None:
async def add_subscription(username, days) -> Optional[date]: # TO DO: Change return type
conn = await DatabaseManager.get_db_connection()
new_end_date = await conn.fetchval(
'''
Expand All @@ -331,7 +336,7 @@ async def remove_subscription(username) -> None:
await conn.close()

@staticmethod
async def get_user_subscription(username) -> date or None:
async def get_user_subscription(username) -> Optional[date]: # TO DO: Change return type
conn = await DatabaseManager.get_db_connection()
subscription_end = await conn.fetchval('SELECT subscription_end FROM users WHERE username = $1', username)
await conn.close()
Expand All @@ -350,7 +355,7 @@ async def add_report(username, report) -> None:
await conn.close()

@staticmethod
async def delete_clip(username, clip_name) -> asyncpg.Record or None:
async def delete_clip(username, clip_name) -> Optional[asyncpg.Record]: # TO DO: Change return type
conn = await DatabaseManager.get_db_connection()
async with conn.transaction():
result = await conn.execute(
Expand All @@ -363,7 +368,7 @@ async def delete_clip(username, clip_name) -> asyncpg.Record or None:
return result

@staticmethod
async def is_clip_name_unique(chat_id: int, clip_name: str) -> bool or None:
async def is_clip_name_unique(chat_id: int, clip_name: str) -> bool:
conn = await DatabaseManager.get_db_connection()
result = await conn.fetchval(
'SELECT COUNT(*) FROM clips WHERE chat_id=$1 AND clip_name=$2',
Expand Down
3 changes: 2 additions & 1 deletion bot/utils/es_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
helpers,
)
import urllib3
from typing import Optional

from bot.settings import settings
from bot.utils.database import DatabaseManager
Expand All @@ -24,7 +25,7 @@
es_password = settings.ES_PASSWORD


async def connect_to_elasticsearch() -> AsyncElasticsearch or None:
async def connect_to_elasticsearch() -> Optional[AsyncElasticsearch]: # TO DO: Change return type
"""
Establishes a connection to Elasticsearch.
"""
Expand Down
9 changes: 5 additions & 4 deletions bot/utils/transcription_search.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging

from aiogram import Dispatcher
from typing import Optional

from bot.middlewares.auth_middleware import AuthorizationMiddleware
from bot.middlewares.error_middleware import ErrorHandlerMiddleware
Expand All @@ -18,7 +19,7 @@ def __init__(self, dispatcher: Dispatcher) -> None:
dispatcher.message.middleware(ErrorHandlerMiddleware())

@staticmethod
async def find_segment_by_quote(quote, season_filter=None, episode_filter=None, index='ranczo-transcriptions', return_all=False) -> list or dict or None:
async def find_segment_by_quote(quote, season_filter=None, episode_filter=None, index='ranczo-transcriptions', return_all=False) -> Optional[dict or list]:
"""
Searches for a segment by a given quote with optional season and episode filters.
Expand Down Expand Up @@ -113,7 +114,7 @@ async def find_segment_by_quote(quote, season_filter=None, episode_filter=None,
async def find_segment_with_context(
self, quote, context_size=30, season_filter=None, episode_filter=None,
index='ranczo-transcriptions',
) -> dict or None:
) -> Optional[dict]:
logger.info(
f"🔍 Searching for quote: '{quote}' with context size: {context_size}, filters - Season: {season_filter}, Episode: {episode_filter}",
)
Expand Down Expand Up @@ -210,7 +211,7 @@ async def find_segment_with_context(
return None

@staticmethod
async def find_video_path_by_episode(season, episode_number, index='ranczo-transcriptions') -> str or None:
async def find_video_path_by_episode(season, episode_number, index='ranczo-transcriptions') -> Optional[str]:
"""
Finds the video path for a given season and episode number.
Expand Down Expand Up @@ -273,7 +274,7 @@ async def find_video_path_by_episode(season, episode_number, index='ranczo-trans
return None

@staticmethod
async def find_episodes_by_season(season, index='ranczo-transcriptions') -> list or None:
async def find_episodes_by_season(season, index='ranczo-transcriptions') -> Optional[list]:
"""
Finds all episodes for a given season.
Expand Down
Loading

0 comments on commit 2907bfc

Please sign in to comment.