Skip to content

Commit 46b536f

Browse files
TLK-1864 agents deployments models refactoring - review fixes
1 parent 0cea965 commit 46b536f

File tree

2 files changed

+71
-51
lines changed

2 files changed

+71
-51
lines changed

src/backend/routers/agent.py

Lines changed: 53 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
AgentToolMetadata as AgentToolMetadataModel,
1515
)
1616
from backend.database_models.database import DBSessionDep
17-
from backend.routers.utils import get_deployment_model_from_agent
17+
from backend.routers.utils import get_deployment_model_from_agent, get_default_deployment_model
1818
from backend.schemas.agent import (
1919
Agent,
2020
AgentPublic,
@@ -53,6 +53,7 @@
5353
)
5454
router.name = RouterName.AGENT
5555

56+
5657
@router.post(
5758
"",
5859
response_model=AgentPublic,
@@ -62,9 +63,9 @@
6263
],
6364
)
6465
async def create_agent(
65-
session: DBSessionDep,
66-
agent: CreateAgentRequest,
67-
ctx: Context = Depends(get_context),
66+
session: DBSessionDep,
67+
agent: CreateAgentRequest,
68+
ctx: Context = Depends(get_context),
6869
) -> AgentPublic:
6970
"""
7071
Create an agent.
@@ -83,6 +84,7 @@ async def create_agent(
8384
logger = ctx.get_logger()
8485

8586
deployment_db, model_db = get_deployment_model_from_agent(agent, session)
87+
default_deployment_db, default_model_db = get_default_deployment_model(session)
8688
try:
8789
if deployment_db and model_db:
8890
agent_data = AgentModel(
@@ -94,8 +96,8 @@ async def create_agent(
9496
organization_id=agent.organization_id,
9597
tools=agent.tools,
9698
is_private=agent.is_private,
97-
deployment_id=deployment_db.id if deployment_db else None,
98-
model_id=model_db.id if model_db else None,
99+
deployment_id=deployment_db.id if deployment_db else default_deployment_db.id if default_deployment_db else None,
100+
model_id=model_db.id if model_db else default_model_db.id if default_model_db else None,
99101
)
100102

101103
created_agent = agent_crud.create_agent(session, agent_data)
@@ -117,13 +119,13 @@ async def create_agent(
117119

118120
@router.get("", response_model=list[AgentPublic])
119121
async def list_agents(
120-
*,
121-
offset: int = 0,
122-
limit: int = 100,
123-
session: DBSessionDep,
124-
visibility: AgentVisibility = AgentVisibility.ALL,
125-
organization_id: Optional[str] = None,
126-
ctx: Context = Depends(get_context),
122+
*,
123+
offset: int = 0,
124+
limit: int = 100,
125+
session: DBSessionDep,
126+
visibility: AgentVisibility = AgentVisibility.ALL,
127+
organization_id: Optional[str] = None,
128+
ctx: Context = Depends(get_context),
127129
) -> list[AgentPublic]:
128130
"""
129131
List all agents.
@@ -161,7 +163,7 @@ async def list_agents(
161163

162164
@router.get("/{agent_id}", response_model=AgentPublic)
163165
async def get_agent_by_id(
164-
agent_id: str, session: DBSessionDep, ctx: Context = Depends(get_context)
166+
agent_id: str, session: DBSessionDep, ctx: Context = Depends(get_context)
165167
) -> Agent:
166168
"""
167169
Args:
@@ -196,7 +198,7 @@ async def get_agent_by_id(
196198

197199
@router.get("/{agent_id}/deployments", response_model=list[DeploymentSchema])
198200
async def get_agent_deployment(
199-
agent_id: str, session: DBSessionDep, ctx: Context = Depends(get_context)
201+
agent_id: str, session: DBSessionDep, ctx: Context = Depends(get_context)
200202
) -> DeploymentSchema:
201203
"""
202204
Args:
@@ -228,10 +230,10 @@ async def get_agent_deployment(
228230
],
229231
)
230232
async def update_agent(
231-
agent_id: str,
232-
new_agent: UpdateAgentRequest,
233-
session: DBSessionDep,
234-
ctx: Context = Depends(get_context),
233+
agent_id: str,
234+
new_agent: UpdateAgentRequest,
235+
session: DBSessionDep,
236+
ctx: Context = Depends(get_context),
235237
) -> AgentPublic:
236238
"""
237239
Update an agent by ID.
@@ -285,9 +287,9 @@ async def update_agent(
285287

286288
@router.delete("/{agent_id}", response_model=DeleteAgent)
287289
async def delete_agent(
288-
agent_id: str,
289-
session: DBSessionDep,
290-
ctx: Context = Depends(get_context),
290+
agent_id: str,
291+
session: DBSessionDep,
292+
ctx: Context = Depends(get_context),
291293
) -> DeleteAgent:
292294
"""
293295
Delete an agent by ID.
@@ -319,10 +321,10 @@ async def delete_agent(
319321

320322

321323
async def handle_tool_metadata_update(
322-
agent: Agent,
323-
new_agent: Agent,
324-
session: DBSessionDep,
325-
ctx: Context = Depends(get_context),
324+
agent: Agent,
325+
new_agent: Agent,
326+
session: DBSessionDep,
327+
ctx: Context = Depends(get_context),
326328
) -> Agent:
327329
"""Update or create tool metadata for an agent.
328330
@@ -360,10 +362,10 @@ async def handle_tool_metadata_update(
360362

361363

362364
async def update_or_create_tool_metadata(
363-
agent: Agent,
364-
new_tool_metadata: AgentToolMetadata,
365-
session: DBSessionDep,
366-
ctx: Context = Depends(get_context),
365+
agent: Agent,
366+
new_tool_metadata: AgentToolMetadata,
367+
session: DBSessionDep,
368+
ctx: Context = Depends(get_context),
367369
) -> None:
368370
"""Update or create tool metadata for an agent.
369371
@@ -389,7 +391,7 @@ async def update_or_create_tool_metadata(
389391

390392
@router.get("/{agent_id}/tool-metadata", response_model=list[AgentToolMetadataPublic])
391393
async def list_agent_tool_metadata(
392-
agent_id: str, session: DBSessionDep, ctx: Context = Depends(get_context)
394+
agent_id: str, session: DBSessionDep, ctx: Context = Depends(get_context)
393395
) -> list[AgentToolMetadataPublic]:
394396
"""
395397
List all agent tool metadata by agent ID.
@@ -421,10 +423,10 @@ async def list_agent_tool_metadata(
421423
response_model=AgentToolMetadataPublic,
422424
)
423425
def create_agent_tool_metadata(
424-
session: DBSessionDep,
425-
agent_id: str,
426-
agent_tool_metadata: CreateAgentToolMetadataRequest,
427-
ctx: Context = Depends(get_context),
426+
session: DBSessionDep,
427+
agent_id: str,
428+
agent_tool_metadata: CreateAgentToolMetadataRequest,
429+
ctx: Context = Depends(get_context),
428430
) -> AgentToolMetadataPublic:
429431
"""
430432
Create an agent tool metadata.
@@ -470,11 +472,11 @@ def create_agent_tool_metadata(
470472

471473
@router.put("/{agent_id}/tool-metadata/{agent_tool_metadata_id}")
472474
async def update_agent_tool_metadata(
473-
agent_id: str,
474-
agent_tool_metadata_id: str,
475-
session: DBSessionDep,
476-
new_agent_tool_metadata: UpdateAgentToolMetadataRequest,
477-
ctx: Context = Depends(get_context),
475+
agent_id: str,
476+
agent_tool_metadata_id: str,
477+
session: DBSessionDep,
478+
new_agent_tool_metadata: UpdateAgentToolMetadataRequest,
479+
ctx: Context = Depends(get_context),
478480
) -> AgentToolMetadata:
479481
"""
480482
Update an agent tool metadata by ID.
@@ -514,10 +516,10 @@ async def update_agent_tool_metadata(
514516

515517
@router.delete("/{agent_id}/tool-metadata/{agent_tool_metadata_id}")
516518
async def delete_agent_tool_metadata(
517-
agent_id: str,
518-
agent_tool_metadata_id: str,
519-
session: DBSessionDep,
520-
ctx: Context = Depends(get_context),
519+
agent_id: str,
520+
agent_tool_metadata_id: str,
521+
session: DBSessionDep,
522+
ctx: Context = Depends(get_context),
521523
) -> DeleteAgentToolMetadata:
522524
"""
523525
Delete an agent tool metadata by ID.
@@ -556,9 +558,9 @@ async def delete_agent_tool_metadata(
556558

557559
@router.post("/batch_upload_file", response_model=list[UploadAgentFileResponse])
558560
async def batch_upload_file(
559-
session: DBSessionDep,
560-
files: list[FastAPIUploadFile] = RequestFile(...),
561-
ctx: Context = Depends(get_context),
561+
session: DBSessionDep,
562+
files: list[FastAPIUploadFile] = RequestFile(...),
563+
ctx: Context = Depends(get_context),
562564
) -> UploadAgentFileResponse:
563565
user_id = ctx.get_user_id()
564566

@@ -580,10 +582,10 @@ async def batch_upload_file(
580582

581583
@router.delete("/{agent_id}/files/{file_id}")
582584
async def delete_agent_file(
583-
agent_id: str,
584-
file_id: str,
585-
session: DBSessionDep,
586-
ctx: Context = Depends(get_context),
585+
agent_id: str,
586+
file_id: str,
587+
session: DBSessionDep,
588+
ctx: Context = Depends(get_context),
587589
) -> DeleteAgentFileResponse:
588590
"""
589591
Delete an agent file by ID.

src/backend/routers/utils.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from backend.config.deployments import ModelDeploymentName
12
from backend.database_models.database import DBSessionDep
23
from backend.schemas.agent import Agent
34

@@ -21,3 +22,20 @@ def get_deployment_model_from_agent(agent: Agent, session: DBSessionDep):
2122
None,
2223
)
2324
return deployment_db, model_db
25+
26+
27+
def get_default_deployment_model(session: DBSessionDep):
28+
from backend.crud import deployment as deployment_crud
29+
30+
deployment_db = deployment_crud.get_deployment_by_name(session, ModelDeploymentName.CoherePlatform)
31+
model_db = None
32+
if deployment_db:
33+
model_db = next(
34+
(
35+
model
36+
for model in deployment_db.models
37+
if model.name == 'command-r-plus'
38+
),
39+
None,
40+
)
41+
return deployment_db, model_db

0 commit comments

Comments
 (0)