Skip to content

Commit 0905ccc

Browse files
committed
implemented granular cache invalidation based on tags
1 parent 5679343 commit 0905ccc

File tree

5 files changed

+348
-7
lines changed

5 files changed

+348
-7
lines changed

examples/in_memory/main.py

Lines changed: 69 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
# pyright: reportGeneralTypeIssues=false
2+
from collections import defaultdict
23
from contextlib import asynccontextmanager
3-
from typing import AsyncIterator, Dict, Optional
4+
from typing import Annotated, AsyncIterator, Dict, List, Optional
45

56
import pendulum
67
import uvicorn
7-
from fastapi import FastAPI
8+
from fastapi import Body, FastAPI, Query
89
from fastapi_cache import FastAPICache
910
from fastapi_cache.backends.inmemory import InMemoryBackend
10-
from fastapi_cache.decorator import cache
11+
from fastapi_cache.decorator import cache, cache_invalidator
12+
from fastapi_cache.tag_provider import TagProvider
1113
from pydantic import BaseModel
1214
from starlette.requests import Request
1315
from starlette.responses import JSONResponse, Response
@@ -65,7 +67,7 @@ async def get_kwargs(name: str):
6567

6668

6769
@app.get("/sync-me")
68-
@cache(namespace="test") # pyright: ignore[reportArgumentType]
70+
@cache(namespace="test") # pyright: ignore[reportArgumentType]
6971
def sync_me():
7072
# as per the fastapi docs, this sync function is wrapped in a thread,
7173
# thereby converted to async. fastapi-cache does the same.
@@ -115,8 +117,10 @@ async def uncached_put():
115117
put_ret = put_ret + 1
116118
return {"value": put_ret}
117119

120+
118121
put_ret2 = 0
119122

123+
120124
@app.get("/cached_put")
121125
@cache(namespace="test", expire=5)
122126
async def cached_put():
@@ -126,7 +130,7 @@ async def cached_put():
126130

127131

128132
@app.get("/namespaced_injection")
129-
@cache(namespace="test", expire=5, injected_dependency_namespace="monty_python") # pyright: ignore[reportArgumentType]
133+
@cache(namespace="test", expire=5, injected_dependency_namespace="monty_python") # pyright: ignore[reportArgumentType]
130134
def namespaced_injection(
131135
__fastapi_cache_request: int = 42, __fastapi_cache_response: int = 17
132136
) -> Dict[str, int]:
@@ -136,5 +140,65 @@ def namespaced_injection(
136140
}
137141

138142

143+
# Note: examples with cache invalidation
144+
files = defaultdict(
145+
list,
146+
{
147+
1: [1, 2, 3],
148+
2: [4, 5, 6],
149+
3: [100],
150+
},
151+
)
152+
153+
FileTagProvider = TagProvider("file")
154+
155+
156+
# Note: providing tags for future granular cache invalidation
157+
@app.get("/files")
158+
@cache(expire=10, tag_provider=FileTagProvider)
159+
async def get_files(file_id_in: Annotated[Optional[List[int]], Query()] = None):
160+
return [
161+
{"id": k, "value": v}
162+
for k, v in files.items()
163+
if (True if not file_id_in else k in file_id_in)
164+
]
165+
166+
167+
# Note: here we're retrieving keys by file_id, so we also need to invalidate this, when file changes
168+
@app.get("/files/{file_id:int}")
169+
@cache(
170+
expire=10,
171+
tag_provider=FileTagProvider,
172+
items_provider=lambda data, method_args, method_kwargs: [
173+
{"id": method_kwargs["file_id"]}
174+
],
175+
)
176+
async def get_file_keys(file_id: int):
177+
if file_id in files:
178+
return files[file_id]
179+
return Response("file id not found")
180+
181+
182+
# Note: here we can use default invalidator, because in response we have :id:
183+
@app.patch("/files/{file_id:int}")
184+
@cache_invalidator(tag_provider=FileTagProvider)
185+
async def edit_file(file_id: int, items: Annotated[List[int], Body(embed=True)]):
186+
files[file_id] = items
187+
return {
188+
"id": file_id,
189+
"value": files[file_id]
190+
}
191+
192+
193+
# Note: here we need to use custom :invalidator: because we don't have access to identifier in response
194+
@app.delete("/files/{file_id:int}")
195+
@cache_invalidator(
196+
tag_provider=FileTagProvider, invalidator=lambda resp, kwargs: kwargs["file_id"]
197+
)
198+
async def delete_file(file_id: int):
199+
if file_id in files:
200+
del files[file_id]
201+
202+
139203
if __name__ == "__main__":
140204
uvicorn.run("main:app", reload=True)

fastapi_cache/decorator.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@
2929

3030
from fastapi_cache import FastAPICache
3131
from fastapi_cache.coder import Coder
32-
from fastapi_cache.types import KeyBuilder
32+
from fastapi_cache.tag_provider import TagProvider
33+
from fastapi_cache.types import ItemsProviderProtocol, KeyBuilder
3334

3435
logger: logging.Logger = logging.getLogger(__name__)
3536
logger.addHandler(logging.NullHandler())
@@ -90,6 +91,8 @@ def cache(
9091
key_builder: Optional[KeyBuilder] = None,
9192
namespace: str = "",
9293
injected_dependency_namespace: str = "__fastapi_cache",
94+
tag_provider: Optional[TagProvider] = None,
95+
items_provider: Optional[ItemsProviderProtocol] = None,
9396
) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[Union[R, Response]]]]:
9497
"""
9598
cache all function
@@ -98,6 +101,8 @@ def cache(
98101
:param expire:
99102
:param coder:
100103
:param key_builder:
104+
:param tag_provider:
105+
:param items_provider:
101106
102107
:return:
103108
"""
@@ -194,6 +199,22 @@ async def ensure_async_func(*args: P.args, **kwargs: P.kwargs) -> R:
194199
f"Error setting cache key '{cache_key}' in backend:",
195200
exc_info=True,
196201
)
202+
else:
203+
if tag_provider:
204+
decoded = coder.decode(to_cache)
205+
try:
206+
await tag_provider.provide(
207+
data=decoded,
208+
parent_key=cache_key,
209+
expire=expire,
210+
items_provider=items_provider,
211+
method_args=args,
212+
method_kwargs=kwargs,
213+
)
214+
except Exception:
215+
logger.warning(
216+
f"Error while providing tags: {cache_key}", exc_info=True
217+
)
197218

198219
if response:
199220
response.headers.update(
@@ -229,3 +250,36 @@ async def ensure_async_func(*args: P.args, **kwargs: P.kwargs) -> R:
229250
return inner
230251

231252
return wrapper
253+
254+
255+
def default_invalidator(response: dict, kwargs: dict) -> str:
256+
return f"{response['id']}"
257+
258+
259+
def cache_invalidator(
260+
tag_provider: TagProvider,
261+
invalidator: Callable[[dict, dict], str] = default_invalidator,
262+
):
263+
def wrapper(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[Union[R, Response]]]:
264+
@wraps(func)
265+
async def inner(*args: P.args, **kwargs: P.kwargs) -> Union[R, Response]:
266+
coder = FastAPICache.get_coder()
267+
268+
response = await func(*args, **kwargs)
269+
270+
data = coder.decode(coder.encode(response))
271+
272+
try:
273+
object_id = invalidator(data, kwargs)
274+
await tag_provider.invalidate(object_id)
275+
except Exception as e:
276+
logger.warning(
277+
f"Exception occurred while invalidating cache: {e}",
278+
exc_info=True,
279+
)
280+
281+
return response
282+
283+
return inner
284+
285+
return wrapper

fastapi_cache/tag_provider.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
import asyncio
2+
from typing import Callable, List, Optional, Union
3+
4+
from fastapi_cache import FastAPICache
5+
from fastapi_cache.types import ItemsProviderProtocol
6+
7+
8+
class TagProvider:
9+
def __init__(
10+
self,
11+
object_type: str,
12+
object_id_provider: Optional[Callable[[dict], str]] = None,
13+
) -> None:
14+
self.object_type = object_type
15+
self.object_id_provider = object_id_provider or self.default_object_id_provider
16+
17+
@staticmethod
18+
def default_object_id_provider(item: dict) -> str:
19+
return f"{item['id']}"
20+
21+
@staticmethod
22+
def default_items_provider(
23+
data: Union[dict, list],
24+
method_args: Optional[tuple] = None,
25+
method_kwargs: Optional[dict] = None,
26+
) -> list[dict]:
27+
return data
28+
29+
def get_tag(self, item: Optional[dict] = None, object_id: Optional[str] = None) -> str:
30+
prefix = FastAPICache.get_prefix()
31+
object_id = object_id or self.object_id_provider(item)
32+
return f"{prefix}:invalidation:{self.object_type}:{object_id}"
33+
34+
@staticmethod
35+
async def _append_value(key: str, parent_key: str, expire: int):
36+
backend = FastAPICache.get_backend()
37+
coder = FastAPICache.get_coder()
38+
value = await backend.get(key)
39+
if value:
40+
value = coder.decode(value)
41+
value.append(parent_key)
42+
else:
43+
value = [parent_key]
44+
await backend.set(key=key, value=coder.encode(value), expire=expire)
45+
46+
async def provide(
47+
self,
48+
data: Union[dict, list],
49+
parent_key: str,
50+
expire: Optional[int] = None,
51+
items_provider: Optional[ItemsProviderProtocol] = None,
52+
method_args: Optional[tuple] = None,
53+
method_kwargs: Optional[dict] = None,
54+
) -> None:
55+
"""
56+
Provides tags for endpoint.
57+
58+
:param data:
59+
:param parent_key:
60+
:param expire:
61+
:param items_provider:
62+
:param method_args:
63+
:param method_kwargs:
64+
"""
65+
provider = items_provider or self.default_items_provider
66+
tasks = [
67+
self._append_value(
68+
key=self.get_tag(item),
69+
parent_key=parent_key,
70+
expire=expire or FastAPICache.get_expire(),
71+
)
72+
for item in provider(data, method_args, method_kwargs)
73+
]
74+
await asyncio.gather(*tasks)
75+
76+
async def invalidate(self, object_id: str) -> None:
77+
"""
78+
Invalidate tags with given object_id
79+
80+
:param object_id: object_id to invalidate
81+
"""
82+
backend = FastAPICache.get_backend()
83+
coder = FastAPICache.get_coder()
84+
tag = self.get_tag(object_id=object_id)
85+
86+
value = await backend.get(tag)
87+
if not value:
88+
return
89+
90+
keys: List[str] = coder.decode(value)
91+
tasks = [backend.clear(key=key) for key in keys]
92+
tasks.append(backend.clear(key=tag))
93+
await asyncio.gather(*tasks)

fastapi_cache/types.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import abc
2-
from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, Union
2+
from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, TypeAlias, Union
33

44
from starlette.requests import Request
55
from starlette.responses import Response
@@ -38,3 +38,23 @@ async def set(self, key: str, value: bytes, expire: Optional[int] = None) -> Non
3838
@abc.abstractmethod
3939
async def clear(self, namespace: Optional[str] = None, key: Optional[str] = None) -> int:
4040
raise NotImplementedError
41+
42+
43+
class _ItemsProviderProtocol(Protocol):
44+
def __call__(self, data: Union[dict, list]):
45+
pass
46+
47+
48+
class _ItemsProviderProtocolWithParams(Protocol):
49+
def __call__(
50+
self,
51+
data: Union[dict, list],
52+
method_args: Optional[tuple] = None,
53+
method_kwargs: Optional[dict] = None,
54+
) -> list[dict]:
55+
pass
56+
57+
58+
ItemsProviderProtocol: TypeAlias = (
59+
_ItemsProviderProtocol | _ItemsProviderProtocolWithParams
60+
)

0 commit comments

Comments
 (0)