diff --git a/src/aiogram_dialog/context/media_storage.py b/src/aiogram_dialog/context/media_storage.py index 6ba290da..3af4a84a 100644 --- a/src/aiogram_dialog/context/media_storage.py +++ b/src/aiogram_dialog/context/media_storage.py @@ -1,4 +1,5 @@ -from typing import Optional +import os +from typing import NamedTuple, Optional, cast from aiogram.types import ContentType from cachetools import LRUCache @@ -7,6 +8,11 @@ from aiogram_dialog.api.protocols import MediaIdStorageProtocol +class CachedMediaId(NamedTuple): + media_id: MediaId + mtime: Optional[float] + + class MediaIdStorage(MediaIdStorageProtocol): def __init__(self, maxsize=10240): self.cache = LRUCache(maxsize=maxsize) @@ -19,7 +25,25 @@ async def get_media_id( ) -> Optional[MediaId]: if not path and not url: return None - return self.cache.get((path, url, type)) + cached = cast( + Optional[CachedMediaId], + self.cache.get((path, url, type)), + ) + if cached is None: + return None + + if cached.mtime is not None: + mtime = self._get_file_mtime(path) + if mtime is not None and mtime != cached.mtime: + return None + return cached.media_id + + def _get_file_mtime(self, path: Optional[str]) -> Optional[float]: + if not path: + return None + if not os.path.exists(path): # noqa: PTH110 + return None + return os.path.getmtime(path) # noqa: PTH204 async def save_media_id( self, @@ -30,4 +54,7 @@ async def save_media_id( ) -> None: if not path and not url: return - self.cache[(path, url, type)] = media_id + self.cache[(path, url, type)] = CachedMediaId( + media_id, + self._get_file_mtime(path), + ) diff --git a/tests/widgets/media/test_media_storage.py b/tests/widgets/media/test_media_storage.py new file mode 100644 index 00000000..83b6ae0e --- /dev/null +++ b/tests/widgets/media/test_media_storage.py @@ -0,0 +1,50 @@ +import asyncio +import os +import tempfile + +import pytest +from aiogram.enums import ContentType + +from aiogram_dialog.context.media_storage import MediaIdStorage + + +@pytest.mark.asyncio +async def test_get_media_id(): + manager = MediaIdStorage() + with tempfile.TemporaryDirectory() as d: + filename = os.path.join(d, "file_test") # noqa: PTH118 + media_id = await manager.get_media_id( + filename, + None, + ContentType.DOCUMENT, + ) + assert media_id is None + + with open(filename, "w") as file: # noqa: PTH123 + file.write("test1") + + await manager.save_media_id( + filename, + None, + ContentType.DOCUMENT, + "test1", + ) + + media_id = await manager.get_media_id( + filename, + None, + ContentType.DOCUMENT, + ) + assert media_id == "test1" + + await asyncio.sleep(0.1) + + with open(filename, "w") as file: # noqa: PTH123 + file.write("test2") + + media_id = await manager.get_media_id( + filename, + None, + ContentType.DOCUMENT, + ) + assert media_id is None