Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions python/valuecell/core/coordinate/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,6 @@ async def close_session(self, session_id: str):
session_id,
Role.SYSTEM,
f"Session closed. {cancelled_count} tasks were cancelled.",
agent_name="orchestrator",
)

async def get_session_history(self, session_id: str):
Expand Down Expand Up @@ -289,6 +288,9 @@ async def _handle_new_request(
"""Handle a new user request"""
session_id = user_input.meta.session_id
thread_id = generate_thread_id()
yield self._response_factory.thread_started(
conversation_id=session_id, thread_id=thread_id
)

# Add user message to session
await self.session_manager.add_message(
Expand Down Expand Up @@ -384,6 +386,9 @@ async def _continue_planning(
original_user_input = context.get_metadata("original_user_input")
thread_id = generate_thread_id()
context.thread_id = thread_id
yield self._response_factory.thread_started(
conversation_id=session_id, thread_id=thread_id
)

if not all([planning_task, original_user_input]):
yield self._response_factory.plan_failed(
Expand Down Expand Up @@ -508,7 +513,6 @@ async def _execute_plan_with_input_support(
session_id,
Role.AGENT,
agent_responses[task.agent_name],
agent_name=task.agent_name,
)
agent_responses[task.agent_name] = ""

Expand Down Expand Up @@ -614,7 +618,7 @@ async def _save_remaining_responses(self, session_id: str, agent_responses: dict
for agent_name, full_response in agent_responses.items():
if full_response.strip():
await self.session_manager.add_message(
session_id, Role.AGENT, full_response, agent_name=agent_name
session_id, Role.AGENT, full_response
)


Expand Down
10 changes: 10 additions & 0 deletions python/valuecell/core/coordinate/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
SystemFailedResponse,
TaskCompletedResponse,
TaskFailedResponse,
ThreadStartedResponse,
ToolCallPayload,
ToolCallResponse,
UnifiedResponseData,
Expand All @@ -28,6 +29,15 @@ def conversation_started(self, conversation_id: str) -> ConversationStartedRespo
data=UnifiedResponseData(conversation_id=conversation_id)
)

def thread_started(
self, conversation_id: str, thread_id: str
) -> ThreadStartedResponse:
return ThreadStartedResponse(
data=UnifiedResponseData(
conversation_id=conversation_id, thread_id=thread_id
)
)

def system_failed(self, conversation_id: str, content: str) -> SystemFailedResponse:
return SystemFailedResponse(
data=UnifiedResponseData(
Expand Down
8 changes: 4 additions & 4 deletions python/valuecell/core/coordinate/tests/test_orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,9 +308,9 @@ async def test_planner_error(
async for chunk in orchestrator.process_user_input(sample_user_input):
out.append(chunk)

assert len(out) == 1
assert "(Error)" in out[0].data.payload.content
assert "Planning failed" in out[0].data.payload.content
assert len(out) == 2
assert "(Error)" in out[1].data.payload.content
assert "Planning failed" in out[1].data.payload.content


@pytest.mark.asyncio
Expand All @@ -328,7 +328,7 @@ async def test_agent_connection_error(
async for chunk in orchestrator.process_user_input(sample_user_input):
out.append(chunk)

assert any("(Error)" in c.data.payload.content for c in out)
assert any("(Error)" in c.data.payload.content for c in out if c.data.payload)


@pytest.mark.asyncio
Expand Down
8 changes: 8 additions & 0 deletions python/valuecell/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def clear_desired_agent(self) -> None:

class SystemResponseEvent(str, Enum):
CONVERSATION_STARTED = "conversation_started"
THREAD_STARTED = "thread_started"
PLAN_REQUIRE_USER_INPUT = "plan_require_user_input"
PLAN_FAILED = "plan_failed"
TASK_FAILED = "task_failed"
Expand Down Expand Up @@ -172,6 +173,13 @@ class ConversationStartedResponse(BaseResponse):
)


class ThreadStartedResponse(BaseResponse):
event: Literal[SystemResponseEvent.THREAD_STARTED] = Field(
SystemResponseEvent.THREAD_STARTED,
description="The event type of the response",
)


class PlanRequireUserInputResponse(BaseResponse):
event: Literal[SystemResponseEvent.PLAN_REQUIRE_USER_INPUT] = Field(
SystemResponseEvent.PLAN_REQUIRE_USER_INPUT,
Expand Down