Skip to content

Commit c52a34a

Browse files
authored
Add support for Litestar (#13)
* Add ApitallyPlugin for Litestar * Add identify_consumer_callback argument and fix validation error capture * Add tests * Fix test coverage * Update readme * Clean up * Fix typing * Add litestar to text matrix in CI * Move start_time definition down * Add filter_openapi_paths argument * Use contextlib.suppress * Improve typing for ApitallyClient * Fix getting correct path from Request object
1 parent e38abfa commit c52a34a

File tree

7 files changed

+729
-16
lines changed

7 files changed

+729
-16
lines changed

.github/workflows/tests.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,9 @@ jobs:
6666
- django-ninja django
6767
- django-ninja==0.22.* django
6868
- django-ninja==0.18.0 django
69+
- litestar
70+
- litestar==2.6.1
71+
- litestar==2.0.1
6972

7073
steps:
7174
- uses: actions/checkout@v4

README.md

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ frameworks:
3030
- [Flask](https://docs.apitally.io/frameworks/flask)
3131
- [Django Ninja](https://docs.apitally.io/frameworks/django-ninja)
3232
- [Django REST Framework](https://docs.apitally.io/frameworks/django-rest-framework)
33+
- [Litestar](https://docs.apitally.io/frameworks/litestar)
3334

3435
Learn more about Apitally on our 🌎 [website](https://apitally.io) or check out
3536
the 📚 [documentation](https://docs.apitally.io).
@@ -50,7 +51,8 @@ example:
5051
pip install apitally[fastapi]
5152
```
5253

53-
The available extras are: `fastapi`, `starlette`, `flask` and `django`.
54+
The available extras are: `fastapi`, `starlette`, `flask`, `django` and
55+
`litestar`.
5456

5557
## Usage
5658

@@ -112,6 +114,27 @@ APITALLY_MIDDLEWARE = {
112114
}
113115
```
114116

117+
### Litestar
118+
119+
This is an example of how to add the Apitally plugin to a Litestar application.
120+
For further instructions, see our
121+
[setup guide for Litestar](https://docs.apitally.io/frameworks/litestar).
122+
123+
```python
124+
from litestar import Litestar
125+
from apitally.litestar import ApitallyPlugin
126+
127+
app = Litestar(
128+
route_handlers=[...],
129+
plugins=[
130+
ApitallyPlugin(
131+
client_id="your-client-id",
132+
env="dev", # or "prod" etc.
133+
),
134+
]
135+
)
136+
```
137+
115138
## Getting help
116139

117140
If you need help please

apitally/client/base.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import re
55
import threading
66
import time
7+
from abc import ABC
78
from collections import Counter
89
from dataclasses import dataclass
910
from math import floor
@@ -26,16 +27,16 @@
2627
TApitallyClient = TypeVar("TApitallyClient", bound="ApitallyClientBase")
2728

2829

29-
class ApitallyClientBase:
30+
class ApitallyClientBase(ABC):
3031
_instance: Optional[ApitallyClientBase] = None
3132
_lock = threading.Lock()
3233

33-
def __new__(cls, *args, **kwargs) -> ApitallyClientBase:
34+
def __new__(cls: Type[TApitallyClient], *args, **kwargs) -> TApitallyClient:
3435
if cls._instance is None:
3536
with cls._lock:
3637
if cls._instance is None:
3738
cls._instance = super().__new__(cls)
38-
return cls._instance
39+
return cast(TApitallyClient, cls._instance)
3940

4041
def __init__(self, client_id: str, env: str) -> None:
4142
if hasattr(self, "client_id"):

apitally/litestar.py

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
import contextlib
2+
import json
3+
import sys
4+
import time
5+
from importlib.metadata import version
6+
from typing import Callable, Dict, List, Optional
7+
8+
from litestar.app import DEFAULT_OPENAPI_CONFIG, Litestar
9+
from litestar.config.app import AppConfig
10+
from litestar.connection import Request
11+
from litestar.datastructures import Headers
12+
from litestar.enums import ScopeType
13+
from litestar.handlers import HTTPRouteHandler
14+
from litestar.plugins import InitPluginProtocol
15+
from litestar.types import ASGIApp, Message, Receive, Scope, Send
16+
17+
from apitally.client.asyncio import ApitallyClient
18+
19+
20+
__all__ = ["ApitallyPlugin"]
21+
22+
23+
class ApitallyPlugin(InitPluginProtocol):
24+
def __init__(
25+
self,
26+
client_id: str,
27+
env: str = "dev",
28+
app_version: Optional[str] = None,
29+
filter_openapi_paths: bool = True,
30+
identify_consumer_callback: Optional[Callable[[Request], Optional[str]]] = None,
31+
) -> None:
32+
self.client = ApitallyClient(client_id=client_id, env=env)
33+
self.app_version = app_version
34+
self.filter_openapi_paths = filter_openapi_paths
35+
self.identify_consumer_callback = identify_consumer_callback
36+
self.openapi_path: Optional[str] = None
37+
38+
def on_app_init(self, app_config: AppConfig) -> AppConfig:
39+
app_config.on_startup.append(self.on_startup)
40+
app_config.middleware.append(self.middleware_factory)
41+
return app_config
42+
43+
def on_startup(self, app: Litestar) -> None:
44+
openapi_config = app.openapi_config or DEFAULT_OPENAPI_CONFIG
45+
self.openapi_path = openapi_config.openapi_controller.path
46+
47+
app_info = {
48+
"openapi": _get_openapi(app),
49+
"paths": [route for route in _get_routes(app) if not self.filter_path(route["path"])],
50+
"versions": _get_versions(self.app_version),
51+
"client": "python:litestar",
52+
}
53+
self.client.set_app_info(app_info)
54+
self.client.start_sync_loop()
55+
56+
def middleware_factory(self, app: ASGIApp) -> ASGIApp:
57+
async def middleware(scope: Scope, receive: Receive, send: Send) -> None:
58+
if scope["type"] == "http" and scope["method"] != "OPTIONS":
59+
request = Request(scope)
60+
response_status = 0
61+
response_time = 0.0
62+
response_headers = Headers()
63+
response_body = b""
64+
start_time = time.perf_counter()
65+
66+
async def send_wrapper(message: Message) -> None:
67+
nonlocal response_time, response_status, response_headers, response_body
68+
if message["type"] == "http.response.start":
69+
response_time = time.perf_counter() - start_time
70+
response_status = message["status"]
71+
response_headers = Headers(message["headers"])
72+
elif message["type"] == "http.response.body" and response_status == 400:
73+
response_body += message["body"]
74+
await send(message)
75+
76+
await app(scope, receive, send_wrapper)
77+
self.add_request(
78+
request=request,
79+
response_status=response_status,
80+
response_time=response_time,
81+
response_headers=response_headers,
82+
response_body=response_body,
83+
)
84+
else:
85+
await app(scope, receive, send) # pragma: no cover
86+
87+
return middleware
88+
89+
def add_request(
90+
self,
91+
request: Request,
92+
response_status: int,
93+
response_time: float,
94+
response_headers: Headers,
95+
response_body: bytes,
96+
) -> None:
97+
if response_status < 100 or not request.route_handler.paths:
98+
return # pragma: no cover
99+
path = self.get_path(request)
100+
if path is None or self.filter_path(path):
101+
return
102+
consumer = self.get_consumer(request)
103+
self.client.request_counter.add_request(
104+
consumer=consumer,
105+
method=request.method,
106+
path=path,
107+
status_code=response_status,
108+
response_time=response_time,
109+
request_size=request.headers.get("Content-Length"),
110+
response_size=response_headers.get("Content-Length"),
111+
)
112+
if response_status == 400 and response_body and len(response_body) < 4096:
113+
with contextlib.suppress(json.JSONDecodeError):
114+
parsed_body = json.loads(response_body)
115+
if (
116+
isinstance(parsed_body, dict)
117+
and "detail" in parsed_body
118+
and isinstance(parsed_body["detail"], str)
119+
and "validation" in parsed_body["detail"].lower()
120+
and "extra" in parsed_body
121+
and isinstance(parsed_body["extra"], list)
122+
):
123+
self.client.validation_error_counter.add_validation_errors(
124+
consumer=consumer,
125+
method=request.method,
126+
path=path,
127+
detail=[
128+
{
129+
"loc": [error.get("source", "body")] + error["key"].split("."),
130+
"msg": error["message"],
131+
"type": "",
132+
}
133+
for error in parsed_body["extra"]
134+
if "key" in error and "message" in error
135+
],
136+
)
137+
138+
def get_path(self, request: Request) -> Optional[str]:
139+
path: List[str] = []
140+
for layer in request.route_handler.ownership_layers:
141+
if isinstance(layer, HTTPRouteHandler):
142+
if len(layer.paths) == 0:
143+
return None # pragma: no cover
144+
path.append(list(layer.paths)[0].lstrip("/"))
145+
else:
146+
path.append(layer.path.lstrip("/"))
147+
return "/" + "/".join(filter(None, path))
148+
149+
def filter_path(self, path: str) -> bool:
150+
if self.filter_openapi_paths and self.openapi_path:
151+
return path == self.openapi_path or path.startswith(self.openapi_path + "/")
152+
return False # pragma: no cover
153+
154+
def get_consumer(self, request: Request) -> Optional[str]:
155+
if hasattr(request.state, "consumer_identifier"):
156+
return str(request.state.consumer_identifier)
157+
if self.identify_consumer_callback is not None:
158+
consumer_identifier = self.identify_consumer_callback(request)
159+
if consumer_identifier is not None:
160+
return str(consumer_identifier)
161+
return None
162+
163+
164+
def _get_openapi(app: Litestar) -> str:
165+
schema = app.openapi_schema.to_schema()
166+
return json.dumps(schema)
167+
168+
169+
def _get_routes(app: Litestar) -> List[Dict[str, str]]:
170+
return [
171+
{"method": method, "path": route.path}
172+
for route in app.routes
173+
for method in route.methods
174+
if route.scope_type == ScopeType.HTTP and method != "OPTIONS"
175+
]
176+
177+
178+
def _get_versions(app_version: Optional[str]) -> Dict[str, str]:
179+
versions = {
180+
"python": f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}",
181+
"apitally": version("apitally"),
182+
"litestar": version("litestar"),
183+
}
184+
if app_version:
185+
versions["app"] = app_version
186+
return versions

0 commit comments

Comments
 (0)