Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(BA-615): Collect metrics for the RPC server #3555

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/3555.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Collect metrics for the RPC server
1 change: 1 addition & 0 deletions src/ai/backend/agent/metrics/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
python_sources(name="src")
47 changes: 47 additions & 0 deletions src/ai/backend/agent/metrics/metric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from typing import Optional, Self

from prometheus_client import Counter, Histogram


class RPCMetricObserver:
_instance: Optional[Self] = None

_rpc_requests: Counter
_rpc_failure_requests: Counter
_rpc_request_duration: Histogram

def __init__(self) -> None:
self._rpc_requests = Counter(
name="backendai_rpc_requests_total",
documentation="Number of RPC requests",
labelnames=["method"],
)
self._rpc_failure_requests = Counter(
name="backendai_rpc_failure_requests_total",
documentation="Number of failed RPC requests",
labelnames=["method", "exception"],
)
self._rpc_request_duration = Histogram(
name="backendai_rpc_request_duration_seconds",
documentation="Duration of RPC requests",
labelnames=["method"],
buckets=[0.1, 1, 10, 30, 60, 300, 600],
)

@classmethod
def instance(cls) -> Self:
if cls._instance is None:
cls._instance = cls()
return cls._instance

def observe_rpc_request_success(self, *, method: str, duration: float) -> None:
self._rpc_requests.labels(method=method).inc()
self._rpc_request_duration.labels(method=method).observe(duration)

def observe_rpc_request_failure(
self, *, method: str, duration: float, exception: Exception
) -> None:
exception_name = exception.__class__.__name__
self._rpc_requests.labels(method=method).inc()
self._rpc_failure_requests.labels(method=method, exception=exception_name).inc()
self._rpc_request_duration.labels(method=method).observe(duration)
32 changes: 32 additions & 0 deletions src/ai/backend/agent/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import signal
import ssl
import sys
import time
from collections import OrderedDict, defaultdict
from datetime import datetime, timezone
from ipaddress import IPv4Address, IPv6Address, ip_network
Expand Down Expand Up @@ -50,6 +51,7 @@
from trafaret.dataerror import DataError as TrafaretDataError
from zmq.auth.certs import load_certificate

from ai.backend.agent.metrics.metric import RPCMetricObserver
from ai.backend.common import config, identity, msgpack, utils
from ai.backend.common.auth import AgentAuthHandler, PublicKey, SecretKey
from ai.backend.common.bgtask import ProgressReporter
Expand Down Expand Up @@ -162,15 +164,18 @@ async def _inner(self: AgentRPCServer, *args, **kwargs):

class RPCFunctionRegistry:
functions: Set[str]
_metric_observer: RPCMetricObserver

def __init__(self) -> None:
self.functions = set()
self._metric_observer = RPCMetricObserver.instance()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this used as a singleton object? Is this observer expected to be used in other contexts as well?


def __call__(
self,
meth: Callable[..., Coroutine[None, None, Any]],
) -> Callable[[AgentRPCServer, RPCMessage], Coroutine[None, None, Any]]:
@functools.wraps(meth)
@_collect_metrics(self._metric_observer)
async def _inner(self_: AgentRPCServer, request: RPCMessage) -> Any:
try:
if request.body is None:
Expand All @@ -195,6 +200,33 @@ async def _inner(self_: AgentRPCServer, request: RPCMessage) -> Any:
return _inner


def _collect_metrics(observer: RPCMetricObserver) -> Callable:
def decorator(meth: Callable) -> Callable[[AgentRPCServer, RPCMessage], Any]:
@functools.wraps(meth)
async def _inner(self: AgentRPCServer, *args, **kwargs) -> Any:
Comment on lines +203 to +206
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can apply more strict type hints here but then we have to update the type hints of RPCFunctionRegistry.__call__().
Let's leave it as a minor issue

start_time = time.perf_counter()
try:
res = await meth(self, *args, **kwargs)
duration = time.perf_counter() - start_time
observer.observe_rpc_request_success(
method=meth.__name__,
duration=duration,
)
return res
except Exception as e:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about handling BaseException here?

duration = time.perf_counter() - start_time
observer.observe_rpc_request_failure(
method=meth.__name__,
duration=duration,
exception=e,
)
raise

return _inner

return decorator


class AgentRPCServer(aobject):
rpc_function: ClassVar[RPCFunctionRegistry] = RPCFunctionRegistry()
rpc_auth_manager_public_key: Optional[PublicKey]
Expand Down
Loading