Skip to content

Commit bcdc3dc

Browse files
authored
Merge pull request #2 from Minibrams/feature/function-wrappers
Added support for overriding endpoint functions while using add_dependencies()
2 parents 100ff13 + 245ceae commit bcdc3dc

File tree

5 files changed

+173
-35
lines changed

5 files changed

+173
-35
lines changed

README.md

Lines changed: 85 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,41 @@ Create decorators that leverage FastAPI's `Depends()` and built-in dependencies,
1111
pip install fastapi-decorators
1212
```
1313

14+
# TL;DR
15+
The library supplies the `add_dependencies()` decorator function which effectively allows you to add argument dependencies to your FastAPI endpoints.
16+
17+
For example, the following three endpoints have the same signature:
18+
```python
19+
# Using normal dependencies
20+
@app.get("/items/{item_id}")
21+
def read_item(item_id: int, _ = Depends(get_current_user)):
22+
...
23+
24+
# Using add_dependency directly
25+
@app.get("/items/{item_id}")
26+
@add_dependencies(Depends(get_current_user))
27+
def read_item(item_id: int):
28+
...
29+
30+
# Using a custom decorator
31+
def authorize():
32+
def dependency(user = Depends(get_current_user)):
33+
return user
34+
return add_dependencies(Depends(dependency))
35+
36+
@app.get("/items/{item_id}")
37+
@authorize()
38+
def read_item(item_id: int):
39+
...
40+
```
41+
1442
# Usage examples
1543

1644
- [Logging decorator](#logging-decorator)
1745
- [Authorization decorator](#authorization-decorator)
1846
- [Custom Response Header decorator](#custom-response-header-decorator)
1947
- [Rate Limiting decorator](#rate-limiting-decorator)
48+
- [Caching decorator](#caching-decorator)
2049
- [Error Handling decorator](#error-handling-decorator)
2150
- [Combining Multiple decorators](#combining-multiple-decorators)
2251
- [Using `add_dependencies()` directly](#using-add_dependencies-directly)
@@ -126,26 +155,70 @@ def limited_endpoint():
126155

127156
```
128157

158+
## Caching decorator
159+
Add caching to your endpoints:
160+
161+
```python
162+
def get_cache() -> dict:
163+
return {} # Use a real cache like Redis or Memcached
164+
165+
def cache_response(max_age: int = 5):
166+
def decorator(func):
167+
168+
# Wrap the endpoint after adding the get_cache dependency
169+
@add_dependencies(cache=Depends(get_cache))
170+
@wraps(func)
171+
def wrapper(*args, cache: dict, **kwargs):
172+
key = func.__name__
173+
174+
if key in cache:
175+
timestamp, data = cache[key]
176+
if time() - timestamp < max_age:
177+
# Cache hit
178+
return data
179+
180+
# Cache miss - call the endpoint as usual
181+
result = func(*args, **kwargs)
182+
183+
# Store the result in the cache
184+
cache[key] = time(), result
185+
return result
186+
187+
return wrapper
188+
return decorator
189+
190+
@app.get("/cached-data")
191+
@cache_response(max_age=10)
192+
def get_cached_data():
193+
...
194+
```
195+
129196
## Error Handling decorator
130197
Create a decorator to handle exceptions and return custom responses:
131198

132199
```python
133200
from fastapi_decorators import add_dependencies
134201
from fastapi import Depends, Response
135-
import traceback
136-
137-
def handle_errors():
138-
async def dependency(response: Response):
139-
try:
140-
yield
141-
except Exception as e:
142-
response.status_code = 500
143-
response.content = f"An error occurred: {str(e)}"
144202

145-
# Optionally print the traceback
146-
traceback.print_exc()
203+
def get_crash_log_storage() -> list:
204+
return [] # Use a real storage like a database
147205

148-
return add_dependencies(Depends(dependency))
206+
def handle_errors():
207+
def decorator(func):
208+
209+
# Wrap the endpoint after adding the crash_logs dependency
210+
@add_dependencies(crash_logs = Depends(get_crash_log_storage))
211+
@wraps(func)
212+
def wrapper(*args, crash_logs: list, **kwargs):
213+
try:
214+
return func(*args, **kwargs)
215+
except Exception as e:
216+
# Log the error and return a custom response
217+
crash_logs.append({ 'error': str(e), 'function': func.__name__ })
218+
return JSONResponse(status_code=500, content={ "detail": str(e) })
219+
220+
return wrapper
221+
return decorator
149222

150223
@app.get("/may-fail")
151224
@handle_errors()

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
setuptools.setup(
77
name="fastapi-decorators",
8-
version="1.0.2",
8+
version="1.0.3",
99
author="Anders Brams",
1010
author_email="anders@brams.dk",
1111
description="Create decorators that leverage FastAPI's `Depends()` and built-in dependencies, enabling you to inject dependencies directly into your decorators.",

src/fastapi_decorators/decorators.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@
44
from functools import wraps
55
from inspect import Parameter, signature
66
from types import MappingProxyType
7-
from typing import Any, Callable, Dict, Tuple, TypeVar, cast
7+
from typing import Any, Callable, Tuple, TypeVar, cast
88

99
F = TypeVar('F', bound=Callable[..., Any])
1010

11-
def add_dependencies(*dependencies: Any) -> Callable[[F], F]:
11+
def add_dependencies(*args: Any, **kwargs: Any) -> Callable[[F], F]:
1212
"""
1313
Decorator to add dependencies to a function without exposing them as arguments.
1414
@@ -50,10 +50,10 @@ def decorator(func: F) -> F:
5050
original_signature = signature(func)
5151
original_parameters = original_signature.parameters
5252

53-
new_parameters = _add_dependency_parameters(dependencies, original_parameters)
53+
new_parameters = _add_dependency_parameters(args, kwargs, original_parameters)
5454
new_signature = original_signature.replace(parameters=tuple(new_parameters.values()))
5555

56-
wrapper = _create_wrapper(func, original_parameters)
56+
wrapper = _create_wrapper(func, new_parameters)
5757
wrapper.__signature__ = new_signature # type: ignore
5858

5959
return cast(F, wrapper)
@@ -62,8 +62,9 @@ def decorator(func: F) -> F:
6262

6363
def _add_dependency_parameters(
6464
dependencies: Tuple[Any, ...],
65+
named_dependencies: dict[str, Any],
6566
original_parameters: MappingProxyType[str, Parameter],
66-
) -> Dict[str, Parameter]:
67+
) -> dict[str, Parameter]:
6768
"""
6869
Adds dependency parameters to the function's parameters.
6970
@@ -84,12 +85,19 @@ def _add_dependency_parameters(
8485
annotation=Any,
8586
)
8687

87-
return new_parameters
88+
for name, dependency in named_dependencies.items():
89+
new_parameters[name] = Parameter(
90+
name,
91+
kind=Parameter.KEYWORD_ONLY,
92+
default=dependency,
93+
annotation=Any,
94+
)
8895

96+
return new_parameters
8997

9098
def _generate_dependency_name(
9199
index: int,
92-
current_parameters: Dict[str, Parameter],
100+
current_parameters: dict[str, Parameter],
93101
) -> str:
94102
"""
95103
Generates a unique name for an anonymous dependency.
@@ -112,7 +120,7 @@ def _generate_dependency_name(
112120

113121
def _create_wrapper(
114122
func: Callable,
115-
original_parameters: MappingProxyType[str, Parameter],
123+
original_parameters: dict[str, Parameter],
116124
) -> Callable:
117125
"""
118126
Creates a wrapper function that filters out dependency arguments.
@@ -128,15 +136,15 @@ def _create_wrapper(
128136
@wraps(func)
129137
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
130138
filtered_kwargs = {
131-
k: v for k, v in kwargs.items() if k in original_parameters
139+
k: v for k, v in kwargs.items() if k in original_parameters and not k.startswith("__dependency_")
132140
}
133141
return await func(*args, **filtered_kwargs)
134142
return async_wrapper
135143
else:
136144
@wraps(func)
137145
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
138146
filtered_kwargs = {
139-
k: v for k, v in kwargs.items() if k in original_parameters
147+
k: v for k, v in kwargs.items() if k in original_parameters and not k.startswith("__dependency_")
140148
}
141149
return func(*args, **filtered_kwargs)
142150
return sync_wrapper

tests/app.py

Lines changed: 50 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
from functools import wraps
2+
from inspect import signature
23
import logging
3-
from time import time
4+
from time import sleep, time
45
from fastapi import Depends, FastAPI, HTTPException, Header, Request, Response
56
from fastapi.responses import JSONResponse
67
from pydantic import BaseModel
7-
from requests import Session
88

99
from fastapi_decorators import add_dependencies
1010

@@ -18,6 +18,7 @@ class DataModel(BaseModel):
1818
value: int
1919

2020
rate_limit_store = {}
21+
cache_storage = {}
2122
error_log = []
2223
fake_db = {
2324
"access_token": "valid_token",
@@ -36,6 +37,12 @@ class DataModel(BaseModel):
3637
def get_db() -> dict:
3738
return fake_db
3839

40+
def get_cache() -> dict:
41+
return cache_storage
42+
43+
def get_crash_log_storage() -> list:
44+
return error_log
45+
3946
def get_rate_limit_store() -> dict:
4047
return rate_limit_store
4148

@@ -98,17 +105,42 @@ async def dependency(
98105
rate_limit_store[request_id] = calls_info
99106
return add_dependencies(Depends(dependency))
100107

108+
def cache_response(max_age: int = 5):
109+
def decorator(func):
110+
111+
# Wrap the endpoint after adding the get_cache dependency
112+
@add_dependencies(cache=Depends(get_cache))
113+
@wraps(func)
114+
def wrapper(*args, cache: dict, **kwargs):
115+
key = func.__name__
116+
117+
if key in cache:
118+
timestamp, data = cache[key]
119+
if time() - timestamp < max_age:
120+
# Cache hit
121+
return data
122+
123+
# Cache miss - call the endpoint as usual
124+
result = func(*args, **kwargs)
125+
126+
# Store the result in the cache
127+
cache[key] = time(), result
128+
return result
129+
130+
return wrapper
131+
return decorator
132+
101133
def handle_errors():
102134
def decorator(func):
135+
@add_dependencies(crash_logs = Depends(get_crash_log_storage))
103136
@wraps(func)
104-
async def wrapper(*args, **kwargs):
137+
def wrapper(*args, crash_logs: list, **kwargs):
105138
try:
106-
return await func(*args, **kwargs)
139+
return func(*args, **kwargs)
107140
except Exception as e:
108-
error_message = f"An error occurred: {str(e)}"
109-
logging.error(error_message)
110-
error_log.append({'error': error_message, 'function': func.__name__})
111-
return JSONResponse(status_code=500, content={"detail": error_message})
141+
crash_logs.append({ 'error': str(e), 'function': func.__name__ })
142+
return JSONResponse(status_code=500, content={ "detail": str(e) })
143+
112144
return wrapper
113145
return decorator
114146

@@ -185,9 +217,18 @@ def limited_endpoint():
185217
"""
186218
return {"message": "You have accessed a rate-limited endpoint"}
187219

220+
@app.get("/expensive-operation")
221+
@cache_response(max_age=5)
222+
def expensive_operation():
223+
"""
224+
Endpoint that is cached for 5 seconds.
225+
"""
226+
sleep(5)
227+
return {"data": time() }
228+
188229
@app.get("/may-fail")
189230
@handle_errors()
190-
async def may_fail_operation(should_fail: bool = False):
231+
def may_fail_operation(should_fail: bool = False):
191232
"""
192233
Endpoint that may raise exceptions.
193234
"""

tests/test_app.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
from time import sleep
1+
from time import sleep, time
22
from fastapi.testclient import TestClient
3-
from app import app, fake_db, rate_limit_store, error_log
3+
from app import app, fake_db, rate_limit_store, cache_storage, error_log
44

55
client = TestClient(app)
66

@@ -99,6 +99,22 @@ def test_rate_limiting():
9999
response = client.get("/limited-endpoint")
100100
assert response.status_code == 200
101101

102+
def test_cached_response():
103+
cache_storage.clear()
104+
data1 = client.get("/expensive-operation").json()
105+
start = time()
106+
data2 = client.get("/expensive-operation").json()
107+
end = time()
108+
109+
assert end - start < 1 # The endpoint takes 5 seconds when not cached
110+
assert data1 == data2
111+
112+
# Wait for the cache to expire
113+
sleep(5)
114+
115+
data3 = client.get("/expensive-operation").json()
116+
assert data1 != data3
117+
102118
def test_error_handling_success():
103119
response = client.get("/may-fail?should_fail=false")
104120
assert response.status_code == 200
@@ -109,7 +125,7 @@ def test_error_handling_failure():
109125

110126
response = client.get("/may-fail?should_fail=true")
111127
assert response.status_code == 500
112-
assert "An error occurred" in response.json()["detail"]
128+
assert len(error_log) == 1
113129

114130
# Check that error was logged
115131
assert len(error_log) == 1

0 commit comments

Comments
 (0)