-
Notifications
You must be signed in to change notification settings - Fork 159
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(BA-615): Collect metrics for the RPC server #3555
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
Collect metrics for the RPC server |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# Declares a build target named "src" for this directory.
# NOTE(review): `python_sources` is presumably provided by the build tool, not defined here — confirm.
python_sources(name="src") |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
from typing import Optional, Self | ||
|
||
from prometheus_client import Counter, Histogram | ||
|
||
|
||
class RPCMetricObserver:
    """Holder of the Prometheus metrics recorded for RPC requests.

    Three metrics are kept:

    * a counter of all RPC requests, labelled by ``method``;
    * a counter of failed RPC requests, labelled by ``method`` and
      ``exception`` (the exception class name);
    * a histogram of request durations, labelled by ``method``.

    Use :meth:`instance` to obtain the shared, lazily created object.
    """

    # Lazily created shared object returned by ``instance()``.
    _instance: Optional[Self] = None

    _rpc_requests: Counter
    _rpc_failure_requests: Counter
    _rpc_request_duration: Histogram

    def __init__(self) -> None:
        """Create the counters and the histogram backing this observer."""
        self._rpc_requests = Counter(
            name="backendai_rpc_requests_total",
            documentation="Number of RPC requests",
            labelnames=["method"],
        )
        self._rpc_failure_requests = Counter(
            name="backendai_rpc_failure_requests_total",
            documentation="Number of failed RPC requests",
            labelnames=["method", "exception"],
        )
        self._rpc_request_duration = Histogram(
            name="backendai_rpc_request_duration_seconds",
            documentation="Duration of RPC requests",
            labelnames=["method"],
            buckets=[0.1, 1, 10, 30, 60, 300, 600],
        )

    @classmethod
    def instance(cls) -> Self:
        """Return the shared observer, creating it on first use."""
        existing = cls._instance
        if existing is not None:
            return existing
        cls._instance = cls()
        return cls._instance

    def observe_rpc_request_success(self, *, method: str, duration: float) -> None:
        """Record one completed RPC request.

        Args:
            method: Value for the ``method`` label.
            duration: Value observed by the duration histogram.
        """
        request_counter = self._rpc_requests.labels(method=method)
        request_counter.inc()
        duration_histogram = self._rpc_request_duration.labels(method=method)
        duration_histogram.observe(duration)

    def observe_rpc_request_failure(
        self, *, method: str, duration: float, exception: Exception
    ) -> None:
        """Record one RPC request that ended with an exception.

        The request still counts towards the total request counter and the
        duration histogram; in addition the failure counter is incremented.

        Args:
            method: Value for the ``method`` label.
            duration: Value observed by the duration histogram.
            exception: The raised exception; its class name becomes the
                ``exception`` label of the failure counter.
        """
        exception_label = exception.__class__.__name__
        request_counter = self._rpc_requests.labels(method=method)
        request_counter.inc()
        failure_counter = self._rpc_failure_requests.labels(
            method=method,
            exception=exception_label,
        )
        failure_counter.inc()
        duration_histogram = self._rpc_request_duration.labels(method=method)
        duration_histogram.observe(duration)
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,6 +12,7 @@ | |
import signal | ||
import ssl | ||
import sys | ||
import time | ||
from collections import OrderedDict, defaultdict | ||
from datetime import datetime, timezone | ||
from ipaddress import IPv4Address, IPv6Address, ip_network | ||
|
@@ -50,6 +51,7 @@ | |
from trafaret.dataerror import DataError as TrafaretDataError | ||
from zmq.auth.certs import load_certificate | ||
|
||
from ai.backend.agent.metrics.metric import RPCMetricObserver | ||
from ai.backend.common import config, identity, msgpack, utils | ||
from ai.backend.common.auth import AgentAuthHandler, PublicKey, SecretKey | ||
from ai.backend.common.bgtask import ProgressReporter | ||
|
@@ -162,15 +164,18 @@ async def _inner(self: AgentRPCServer, *args, **kwargs): | |
|
||
class RPCFunctionRegistry: | ||
functions: Set[str] | ||
_metric_observer: RPCMetricObserver | ||
|
||
def __init__(self) -> None: | ||
self.functions = set() | ||
self._metric_observer = RPCMetricObserver.instance() | ||
|
||
def __call__( | ||
self, | ||
meth: Callable[..., Coroutine[None, None, Any]], | ||
) -> Callable[[AgentRPCServer, RPCMessage], Coroutine[None, None, Any]]: | ||
@functools.wraps(meth) | ||
@_collect_metrics(self._metric_observer) | ||
async def _inner(self_: AgentRPCServer, request: RPCMessage) -> Any: | ||
try: | ||
if request.body is None: | ||
|
@@ -195,6 +200,33 @@ async def _inner(self_: AgentRPCServer, request: RPCMessage) -> Any: | |
return _inner | ||
|
||
|
||
def _collect_metrics(
    observer: RPCMetricObserver,
    method_name: Optional[str] = None,
) -> Callable:
    """Build a decorator that reports outcome and duration of an RPC coroutine.

    Every call of the decorated coroutine is timed with
    ``time.perf_counter()`` and reported to *observer* through
    ``observe_rpc_request_success`` or ``observe_rpc_request_failure``.

    Args:
        observer: Receives the success/failure observations.
        method_name: Explicit value for the ``method`` label.  When omitted,
            the label is the ``__name__`` of the returned wrapper, looked up
            at call time.

    Returns:
        A decorator for coroutine functions taking the server instance as
        their first argument.

    Raises:
        Whatever the wrapped coroutine raises; exceptions are re-raised
        unchanged after being reported.
    """

    # NOTE(review): a reviewer remarked that stricter type hints could be
    # applied here, at the cost of updating other type hints as well.
    def decorator(meth: Callable) -> Callable[[AgentRPCServer, RPCMessage], Any]:
        @functools.wraps(meth)
        async def _inner(self: AgentRPCServer, *args, **kwargs) -> Any:
            # Resolve the label per call instead of reading ``meth.__name__``:
            # when this decorator sits underneath another ``functools.wraps``
            # (as in ``RPCFunctionRegistry.__call__``), ``meth`` is an
            # anonymous inner wrapper, whereas this wrapper's ``__name__`` has
            # been overwritten with the real handler name by the time it runs.
            label = method_name if method_name is not None else _inner.__name__
            start_time = time.perf_counter()
            try:
                res = await meth(self, *args, **kwargs)
            # NOTE(review): a reviewer asked about handling something else
            # here (remark truncated).  Only ``Exception`` subclasses are
            # reported, so e.g. ``asyncio.CancelledError`` passes through
            # without being counted.
            except Exception as e:
                observer.observe_rpc_request_failure(
                    method=label,
                    duration=time.perf_counter() - start_time,
                    exception=e,
                )
                raise
            else:
                # Kept outside the ``try`` body so that an error raised by the
                # observer itself is not misreported as an RPC failure.
                observer.observe_rpc_request_success(
                    method=label,
                    duration=time.perf_counter() - start_time,
                )
                return res

        return _inner

    return decorator
|
||
|
||
class AgentRPCServer(aobject): | ||
rpc_function: ClassVar[RPCFunctionRegistry] = RPCFunctionRegistry() | ||
rpc_auth_manager_public_key: Optional[PublicKey] | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this used as a singleton object? Is this observer expected to be used in other contexts?