Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(BA-460): Cache gpu_alloc_map in Redis, and Add RescanGPUAllocMaps mutation #3293

Open
wants to merge 33 commits into
base: topic/06-13-feat_support_scanning_gpu_allocation
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
a595ba4
feat: Cache gpu_alloc_map, and Add ScanGPUAllocMap mutation
jopemachine Dec 24, 2024
cbaf50b
chore: Add news fragment
jopemachine Dec 24, 2024
2d20c44
chore: Add news fragment
jopemachine Dec 24, 2024
1f7243c
chore: fix typo
jopemachine Dec 24, 2024
b342a19
chore: Improve news fragment
jopemachine Dec 24, 2024
f9fc023
fix: Add milestone comment
jopemachine Dec 24, 2024
b919dc9
fix: Wrong impl of AgentRegistry.scan_gpu_alloc_map
jopemachine Dec 24, 2024
61913ba
fix: Add `extra_fixtures`
jopemachine Dec 24, 2024
0decec0
feat: Add `test_scan_gpu_alloc_maps` test case
jopemachine Dec 24, 2024
50591f3
feat: Add update call count check
jopemachine Dec 24, 2024
a64a49d
fix: Improve `test_scan_gpu_alloc_maps`
jopemachine Dec 26, 2024
375f677
fix: Improve `test_scan_gpu_alloc_maps`
jopemachine Dec 26, 2024
6268367
chore: Rename variables
jopemachine Dec 26, 2024
cb69aa9
fix: `ScanGPUAllocMaps` -> `RescanGPUAllocMaps`
jopemachine Dec 26, 2024
aa53962
fix: Broken test
jopemachine Dec 26, 2024
c7222f3
fix: Remove useless `_default_host`
jopemachine Dec 26, 2024
91ba036
chore: Rename news fragment
jopemachine Dec 26, 2024
abd06d9
feat: Improve error handling
jopemachine Dec 26, 2024
17d8e5a
fix: Improve exception handling and test case
jopemachine Dec 26, 2024
4282bf3
fix: Replace useless `mock_agent_registry_ctx` with local_config's `r…
jopemachine Dec 26, 2024
c2e2e09
fix: Wrong reference to `redis_stat`
jopemachine Dec 26, 2024
a4369a7
docs: Add description about agent_id
jopemachine Dec 26, 2024
16eb8a8
chore: update GraphQL schema dump
jopemachine Dec 26, 2024
64c33cf
feat: Call agent rpc call in parallel
jopemachine Dec 26, 2024
3c16914
fix: Update milestone
jopemachine Jan 8, 2025
afc1740
fix: Update milestone
jopemachine Jan 8, 2025
ab139c1
fix: Update milestone
jopemachine Jan 24, 2025
8379c07
misc: `25.3.1`
jopemachine Feb 19, 2025
3d346a6
feat: Add `UUIDFloatMap` custom scalar type
jopemachine Feb 20, 2025
2215c6d
chore: update api schema dump
jopemachine Feb 20, 2025
95d6405
fix: Try to fix CI
jopemachine Feb 20, 2025
e358417
fix: Broken CI
jopemachine Feb 21, 2025
264ae0a
fix: Broken test
jopemachine Feb 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/3293.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Cache `gpu_alloc_map` in Redis, and Add `RescanGPUAllocMaps` mutation for update the `gpu_alloc_map`s.
11 changes: 11 additions & 0 deletions docs/manager/graphql-reference/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -1897,6 +1897,12 @@
This action cannot be undone.
"""
purge_user(email: String!, props: PurgeUserInput!): PurgeUser

"""Added in 25.3.1."""
rescan_gpu_alloc_maps(

Check notice on line 1902 in docs/manager/graphql-reference/schema.graphql

View workflow job for this annotation

GitHub Actions / GraphQL Inspector

Field 'rescan_gpu_alloc_maps' was added to object type 'Mutations'

Field 'rescan_gpu_alloc_maps' was added to object type 'Mutations'
"""Agent ID to rescan GPU alloc map"""
agent_id: String!
): RescanGPUAllocMaps
create_keypair(props: KeyPairInput!, user_id: String!): CreateKeyPair
modify_keypair(access_key: String!, props: ModifyKeyPairInput!): ModifyKeyPair
delete_keypair(access_key: String!): DeleteKeyPair
Expand Down Expand Up @@ -2370,6 +2376,11 @@
purge_shared_vfolders: Boolean
}

"""Added in 25.3.1."""
type RescanGPUAllocMaps {

Check notice on line 2380 in docs/manager/graphql-reference/schema.graphql

View workflow job for this annotation

GitHub Actions / GraphQL Inspector

Type 'RescanGPUAllocMaps' was added

Type 'RescanGPUAllocMaps' was added
task_id: UUID
}

type CreateKeyPair {
ok: Boolean
msg: String
Expand Down
2 changes: 2 additions & 0 deletions src/ai/backend/manager/models/gql.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@
AgentSummary,
AgentSummaryList,
ModifyAgent,
RescanGPUAllocMaps,
)
from .gql_models.container_registry import (
CreateContainerRegistryQuota,
Expand Down Expand Up @@ -282,6 +283,7 @@ class Mutations(graphene.ObjectType):
modify_user = ModifyUser.Field()
delete_user = DeleteUser.Field()
purge_user = PurgeUser.Field()
rescan_gpu_alloc_maps = RescanGPUAllocMaps.Field(description="Added in 25.3.1.")

# admin only
create_keypair = CreateKeyPair.Field()
Expand Down
87 changes: 80 additions & 7 deletions src/ai/backend/manager/models/gql_models/agent.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import json
import logging
import uuid
from collections.abc import Iterable, Mapping, Sequence
from typing import (
Expand All @@ -18,13 +20,14 @@
from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection

from ai.backend.common import msgpack, redis_helper
from ai.backend.common.bgtask import ProgressReporter
from ai.backend.common.types import (
AccessKey,
AgentId,
BinarySize,
HardwareMetadata,
)
from ai.backend.manager.models.gql_models.base import UUIDFloatMap
from ai.backend.logging.utils import BraceStyleAdapter

from ..agent import (
AgentRow,
Expand Down Expand Up @@ -56,12 +59,15 @@
from ..rbac.context import ClientContext
from ..rbac.permission_defs import AgentPermission
from ..user import UserRole, users
from .base import UUIDFloatMap
from .fields import AgentPermissionField
from .kernel import KernelConnection, KernelNode

if TYPE_CHECKING:
from ..gql import GraphQueryContext

log = BraceStyleAdapter(logging.getLogger(__spec__.name))

__all__ = (
"Agent",
"AgentNode",
Expand Down Expand Up @@ -96,6 +102,17 @@
}


async def _resolve_gpu_alloc_map(ctx: GraphQueryContext, agent_id: AgentId) -> dict[str, float]:
raw_alloc_map = await redis_helper.execute(
ctx.redis_stat, lambda r: r.get(f"gpu_alloc_map.{agent_id}")
)

if raw_alloc_map:
alloc_map = json.loads(raw_alloc_map)
return UUIDFloatMap.parse_value({k: float(v) for k, v in alloc_map.items()})
return {}


class AgentNode(graphene.ObjectType):
class Meta:
interfaces = (AsyncNode,)
Expand Down Expand Up @@ -180,9 +197,8 @@ async def resolve_live_stat(self, info: graphene.ResolveInfo) -> Any:
loader = ctx.dataloader_manager.get_loader_by_func(ctx, self.batch_load_live_stat)
return await loader.load(self.id)

async def resolve_gpu_alloc_map(self, info: graphene.ResolveInfo) -> Mapping[str, int]:
ctx: GraphQueryContext = info.context
return await ctx.registry.scan_gpu_alloc_map(self.id)
async def resolve_gpu_alloc_map(self, info: graphene.ResolveInfo) -> dict[str, float]:
return await _resolve_gpu_alloc_map(info.context, self.id)

async def resolve_hardware_metadata(
self,
Expand Down Expand Up @@ -434,9 +450,8 @@ async def resolve_container_count(self, info: graphene.ResolveInfo) -> int:
loader = ctx.dataloader_manager.get_loader_by_func(ctx, Agent.batch_load_container_count)
return await loader.load(self.id)

async def resolve_gpu_alloc_map(self, info: graphene.ResolveInfo) -> Mapping[str, int]:
ctx: GraphQueryContext = info.context
return await ctx.registry.scan_gpu_alloc_map(self.id)
async def resolve_gpu_alloc_map(self, info: graphene.ResolveInfo) -> dict[str, float]:
return await _resolve_gpu_alloc_map(info.context, self.id)

_queryfilter_fieldspec: Mapping[str, FieldSpecItem] = {
"id": ("id", None),
Expand Down Expand Up @@ -879,3 +894,61 @@ async def mutate(

update_query = sa.update(agents).values(data).where(agents.c.id == id)
return await simple_db_mutate(cls, graph_ctx, update_query)


class RescanGPUAllocMaps(graphene.Mutation):
allowed_roles = (UserRole.SUPERADMIN,)

class Meta:
description = "Added in 25.3.1."

class Arguments:
agent_id = graphene.String(
description="Agent ID to rescan GPU alloc map",
required=True,
)

task_id = graphene.UUID()

@classmethod
@privileged_mutation(
UserRole.SUPERADMIN,
lambda id, **kwargs: (None, id),
)
async def mutate(
cls,
root,
info: graphene.ResolveInfo,
agent_id: str,
) -> RescanGPUAllocMaps:
log.info("rescanning GPU alloc maps")
graph_ctx: GraphQueryContext = info.context

async def _rescan_alloc_map_task(reporter: ProgressReporter) -> None:
await reporter.update(message=f"Agent {agent_id} GPU alloc map scanning...")

reporter_msg = ""
try:
alloc_map: Mapping[str, Any] = await graph_ctx.registry.scan_gpu_alloc_map(
AgentId(agent_id)
)
key = f"gpu_alloc_map.{agent_id}"
await redis_helper.execute(
graph_ctx.registry.redis_stat,
lambda r: r.set(name=key, value=json.dumps(alloc_map)),
)
except Exception as e:
reporter_msg = f"Failed to scan GPU alloc map for agent {agent_id}: {str(e)}"
log.error(reporter_msg)
else:
reporter_msg = f"Agent {agent_id} GPU alloc map scanned."

await reporter.update(
increment=1,
message=reporter_msg,
)

await reporter.update(message="GPU alloc map scanning completed")

task_id = await graph_ctx.background_task_manager.start(_rescan_alloc_map_task)
return RescanGPUAllocMaps(task_id=task_id)
13 changes: 12 additions & 1 deletion tests/manager/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,8 @@ def local_config(
redis_addr = redis_container[1]
postgres_addr = postgres_container[1]

build_root = Path(os.environ["BACKEND_BUILD_ROOT"])

# Establish a self-contained config.
cfg = LocalConfig({
**etcd_config_iv.check({
Expand Down Expand Up @@ -208,6 +210,7 @@ def local_config(
"service-addr": HostPortPair("127.0.0.1", 29100 + get_parallel_slot() * 10),
"allowed-plugins": set(),
"disabled-plugins": set(),
"rpc-auth-manager-keypair": f"{build_root}/fixtures/manager/manager.key_secret",
},
"pyroscope": {
"enabled": False,
Expand Down Expand Up @@ -266,7 +269,15 @@ def etcd_fixture(
"volumes": {
"_mount": str(vfolder_mount),
"_fsprefix": str(vfolder_fsprefix),
"_default_host": str(vfolder_host),
"default_host": str(vfolder_host),
"proxies": {
"local": {
"client_api": "http://127.0.0.1:6021",
"manager_api": "https://127.0.0.1:6022",
"secret": "some-secret-shared-with-storage-proxy",
"ssl_verify": "false",
}
},
},
"nodes": {},
"config": {
Expand Down
178 changes: 178 additions & 0 deletions tests/manager/models/gql_models/test_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
import asyncio
import json
from unittest.mock import AsyncMock, patch

import attr
import pytest
from graphene import Schema
from graphene.test import Client

from ai.backend.common import redis_helper
from ai.backend.common.events import BgtaskDoneEvent, EventDispatcher
from ai.backend.common.types import AgentId
from ai.backend.manager.api.context import RootContext
from ai.backend.manager.models.agent import AgentStatus
from ai.backend.manager.models.gql import GraphQueryContext, Mutations, Queries
from ai.backend.manager.server import (
agent_registry_ctx,
background_task_ctx,
database_ctx,
event_dispatcher_ctx,
hook_plugin_ctx,
monitoring_ctx,
network_plugin_ctx,
redis_ctx,
shared_config_ctx,
storage_manager_ctx,
)


@pytest.fixture(scope="module")
def client() -> Client:
return Client(Schema(query=Queries, mutation=Mutations, auto_camelcase=False))


def get_graphquery_context(root_context: RootContext) -> GraphQueryContext:
return GraphQueryContext(
schema=None, # type: ignore
dataloader_manager=None, # type: ignore
local_config=None, # type: ignore
shared_config=None, # type: ignore
etcd=None, # type: ignore
user={"domain": "default", "role": "superadmin"},
access_key="AKIAIOSFODNN7EXAMPLE",
db=root_context.db, # type: ignore
redis_stat=None, # type: ignore
redis_image=None, # type: ignore
redis_live=None, # type: ignore
manager_status=None, # type: ignore
known_slot_types=None, # type: ignore
background_task_manager=root_context.background_task_manager, # type: ignore
storage_manager=None, # type: ignore
registry=root_context.registry, # type: ignore
idle_checker_host=None, # type: ignore
network_plugin_ctx=None, # type: ignore
services_ctx=None, # type: ignore
)


EXTRA_FIXTURES = {
"agents": [
{
"id": "i-ag1",
"status": AgentStatus.ALIVE.name,
"scaling_group": "default",
"schedulable": True,
"region": "local",
"available_slots": {},
"occupied_slots": {},
"addr": "tcp://127.0.0.1:6011",
"public_host": "127.0.0.1",
"version": "24.12.0a1",
"architecture": "x86_64",
"compute_plugins": {},
}
]
}


@patch("ai.backend.manager.registry.AgentRegistry.scan_gpu_alloc_map", new_callable=AsyncMock)
@pytest.mark.asyncio
@pytest.mark.timeout(60)
@pytest.mark.parametrize(
"test_case, extra_fixtures",
[
(
{
"mock_agent_responses": [
{
"00000000-0000-0000-0000-000000000001": "10.00",
"00000000-0000-0000-0000-000000000002": "5.00",
},
],
"expected": {
"redis": [
{
"00000000-0000-0000-0000-000000000001": "10.00",
"00000000-0000-0000-0000-000000000002": "5.00",
},
],
},
},
EXTRA_FIXTURES,
),
],
)
async def test_scan_gpu_alloc_maps(
mock_agent_responses,
client,
local_config,
etcd_fixture,
database_fixture,
create_app_and_client,
test_case,
extra_fixtures,
):
test_app, _ = await create_app_and_client(
[
shared_config_ctx,
database_ctx,
redis_ctx,
monitoring_ctx,
hook_plugin_ctx,
event_dispatcher_ctx,
storage_manager_ctx,
network_plugin_ctx,
agent_registry_ctx,
background_task_ctx,
],
[],
)

root_ctx: RootContext = test_app["_root.context"]
dispatcher: EventDispatcher = root_ctx.event_dispatcher
done_handler_ctx = {}
done_event = asyncio.Event()

async def done_sub(
context: None,
source: AgentId,
event: BgtaskDoneEvent,
) -> None:
update_body = attr.asdict(event) # type: ignore
done_handler_ctx.update(**update_body)
done_event.set()

dispatcher.subscribe(BgtaskDoneEvent, None, done_sub)

mock_agent_responses.side_effect = test_case["mock_agent_responses"]

context = get_graphquery_context(root_ctx)
query = """
mutation ($agent_id: String!) {
rescan_gpu_alloc_maps (agent_id: $agent_id) {
task_id
}
}
"""

res = await client.execute_async(
query,
context_value=context,
variables={
"agent_id": "i-ag1",
},
)

await done_event.wait()

assert str(done_handler_ctx["task_id"]) == res["data"]["rescan_gpu_alloc_maps"]["task_id"]
alloc_map_keys = [f"gpu_alloc_map.{agent['id']}" for agent in extra_fixtures["agents"]]
raw_alloc_map_cache = await redis_helper.execute(
root_ctx.redis_stat,
lambda r: r.mget(*alloc_map_keys),
)
alloc_map_cache = [
json.loads(stat) if stat is not None else None for stat in raw_alloc_map_cache
]
assert alloc_map_cache == test_case["expected"]["redis"]
Loading