1 change: 1 addition & 0 deletions lib/crewai/src/crewai/events/types/llm_events.py
@@ -9,6 +9,7 @@
 class LLMEventBase(BaseEvent):
     from_task: Any | None = None
     from_agent: Any | None = None
+    message_id: str | None = None

     def __init__(self, **data):
         if data.get("from_task"):
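Because the new field lives on `LLMEventBase`, every LLM event that subclasses it (`LLMCallStartedEvent`, `LLMCallFailedEvent`, `LLMStreamChunkEvent`, and the completion event) inherits `message_id` with no per-class changes, and the `None` default keeps older emit sites valid. A condensed view of the resulting model, reconstructed from the hunk above with the remaining members elided:

```python
class LLMEventBase(BaseEvent):
    from_task: Any | None = None
    from_agent: Any | None = None
    # New: shared by every event emitted during a single LLM.call(),
    # so consumers can group related events; None when not supplied.
    message_id: str | None = None
```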
38 changes: 33 additions & 5 deletions lib/crewai/src/crewai/llm.py
@@ -18,6 +18,7 @@
     TypedDict,
     cast,
 )
+import uuid

 from dotenv import load_dotenv
 import httpx
@@ -532,6 +533,7 @@ def _handle_streaming_response(
         from_task: Task | None = None,
         from_agent: Agent | None = None,
         response_model: type[BaseModel] | None = None,
+        message_id: str | None = None,
     ) -> Any:
         """Handle a streaming response from the LLM.

@@ -626,6 +628,7 @@ def _handle_streaming_response(
                         available_functions=available_functions,
                         from_task=from_task,
                         from_agent=from_agent,
+                        message_id=message_id,
                     )

                 if result is not None:
@@ -646,6 +649,7 @@
                         chunk=chunk_content,
                         from_task=from_task,
                         from_agent=from_agent,
+                        message_id=message_id,
                     ),
                 )
             # --- 4) Fallback to non-streaming if no content received
@@ -763,6 +767,7 @@ def _handle_streaming_response(
                     from_task=from_task,
                     from_agent=from_agent,
                     messages=params["messages"],
+                    message_id=message_id,
                 )
                 return structured_response

@@ -772,11 +777,12 @@
                 from_task=from_task,
                 from_agent=from_agent,
                 messages=params["messages"],
+                message_id=message_id,
             )
             return full_response

         # --- 9) Handle tool calls if present
-        tool_result = self._handle_tool_call(tool_calls, available_functions)
+        tool_result = self._handle_tool_call(tool_calls, available_functions, from_task, from_agent, message_id)
         if tool_result is not None:
             return tool_result

@@ -792,6 +798,7 @@
                 from_task=from_task,
                 from_agent=from_agent,
                 messages=params["messages"],
+                message_id=message_id,
             )
             return full_response

@@ -810,13 +817,14 @@
                     from_task=from_task,
                     from_agent=from_agent,
                     messages=params["messages"],
+                    message_id=message_id,
                 )
                 return full_response

             crewai_event_bus.emit(
                 self,
                 event=LLMCallFailedEvent(
-                    error=str(e), from_task=from_task, from_agent=from_agent
+                    error=str(e), from_task=from_task, from_agent=from_agent, message_id=message_id
                 ),
             )
             raise Exception(f"Failed to get streaming response: {e!s}") from e
@@ -828,6 +836,7 @@ def _handle_streaming_tool_calls(
         available_functions: dict[str, Any] | None = None,
         from_task: Task | None = None,
         from_agent: Agent | None = None,
+        message_id: str | None = None,
     ) -> Any:
         for tool_call in tool_calls:
             current_tool_accumulator = accumulated_tool_args[tool_call.index]
@@ -847,6 +856,7 @@
                         chunk=tool_call.function.arguments,
                         from_task=from_task,
                         from_agent=from_agent,
+                        message_id=message_id,
                     ),
                 )

@@ -861,6 +871,9 @@
                     return self._handle_tool_call(
                         [current_tool_accumulator],
                         available_functions,
+                        from_task,
+                        from_agent,
+                        message_id,
                     )
                 except json.JSONDecodeError:
                     continue
@@ -914,6 +927,7 @@ def _handle_non_streaming_response(
         from_task: Task | None = None,
         from_agent: Agent | None = None,
         response_model: type[BaseModel] | None = None,
+        message_id: str | None = None,
     ) -> str | Any:
         """Handle a non-streaming response from the LLM.

@@ -954,6 +968,7 @@ def _handle_non_streaming_response(
                 from_task=from_task,
                 from_agent=from_agent,
                 messages=params["messages"],
+                message_id=message_id,
             )
             return structured_response

@@ -982,6 +997,7 @@ def _handle_non_streaming_response(
                 from_task=from_task,
                 from_agent=from_agent,
                 messages=params["messages"],
+                message_id=message_id,
             )
             return structured_response

@@ -1013,6 +1029,7 @@ def _handle_non_streaming_response(
                 from_task=from_task,
                 from_agent=from_agent,
                 messages=params["messages"],
+                message_id=message_id,
             )
             return text_response

@@ -1022,7 +1039,7 @@ def _handle_non_streaming_response(

         # --- 7) Handle tool calls if present
         tool_result = self._handle_tool_call(
-            tool_calls, available_functions, from_task, from_agent
+            tool_calls, available_functions, from_task, from_agent, message_id
         )
         if tool_result is not None:
             return tool_result
@@ -1033,6 +1050,7 @@
             from_task=from_task,
             from_agent=from_agent,
             messages=params["messages"],
+            message_id=message_id,
         )
         return text_response

@@ -1042,6 +1060,7 @@ def _handle_tool_call(
         available_functions: dict[str, Any] | None = None,
         from_task: Task | None = None,
         from_agent: Agent | None = None,
+        message_id: str | None = None,
     ) -> Any:
         """Handle a tool call from the LLM.

@@ -1101,6 +1120,7 @@ def _handle_tool_call(
                 call_type=LLMCallType.TOOL_CALL,
                 from_task=from_task,
                 from_agent=from_agent,
+                message_id=message_id,
             )
             return result
         except Exception as e:
@@ -1111,7 +1131,7 @@ def _handle_tool_call(
             logging.error(f"Error executing function '{function_name}': {e}")
             crewai_event_bus.emit(
                 self,
-                event=LLMCallFailedEvent(error=f"Tool execution error: {e!s}"),
+                event=LLMCallFailedEvent(error=f"Tool execution error: {e!s}", message_id=message_id),
             )
             crewai_event_bus.emit(
                 self,
@@ -1161,6 +1181,8 @@ def call(
             ValueError: If response format is not supported
             LLMContextLengthExceededError: If input exceeds model's context limit
         """
+        message_id = uuid.uuid4().hex
+
         crewai_event_bus.emit(
             self,
             event=LLMCallStartedEvent(
@@ -1171,6 +1193,7 @@
                 from_task=from_task,
                 from_agent=from_agent,
                 model=self.model,
+                message_id=message_id,
             ),
         )

@@ -1202,6 +1225,7 @@ def call(
                     from_task=from_task,
                     from_agent=from_agent,
                     response_model=response_model,
+                    message_id=message_id,
                 )

             return self._handle_non_streaming_response(
@@ -1211,6 +1235,7 @@ def call(
                 from_task=from_task,
                 from_agent=from_agent,
                 response_model=response_model,
+                message_id=message_id,
             )
         except LLMContextLengthExceededError:
             # Re-raise LLMContextLengthExceededError as it should be handled
@@ -1248,7 +1273,7 @@ def call(
             crewai_event_bus.emit(
                 self,
                 event=LLMCallFailedEvent(
-                    error=str(e), from_task=from_task, from_agent=from_agent
+                    error=str(e), from_task=from_task, from_agent=from_agent, message_id=message_id
                 ),
             )
             raise
@@ -1260,6 +1285,7 @@ def _handle_emit_call_events(
         from_task: Task | None = None,
         from_agent: Agent | None = None,
         messages: str | list[LLMMessage] | None = None,
+        message_id: str | None = None,
     ) -> None:
         """Handle the events for the LLM call.

@@ -1269,6 +1295,7 @@ def _handle_emit_call_events(
             from_task: Optional task object
             from_agent: Optional agent object
             messages: Optional messages object
+            message_id: Optional message identifier
         """
         crewai_event_bus.emit(
             self,
@@ -1279,6 +1306,7 @@
                 from_task=from_task,
                 from_agent=from_agent,
                 model=self.model,
+                message_id=message_id,
             ),
         )

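Taken together, these changes stamp a fresh `uuid4` hex string onto every event emitted within a single `LLM.call()` invocation, so consumers can correlate stream chunks, tool calls, and completion or failure events even when several calls run concurrently. A minimal listener sketch follows; it is a hypothetical consumer, not part of this PR, and the `crewai_event_bus.on(...)` decorator registration and import paths are assumptions based on the paths in this diff:

```python
from collections import defaultdict

from crewai.events import crewai_event_bus
from crewai.events.types.llm_events import (
    LLMCallCompletedEvent,
    LLMStreamChunkEvent,
)

# Buffer streamed chunks per message_id so chunks from concurrent
# LLM calls are never merged into the wrong transcript.
chunks_by_message: defaultdict[str, list[str]] = defaultdict(list)


@crewai_event_bus.on(LLMStreamChunkEvent)
def collect_chunk(source, event: LLMStreamChunkEvent) -> None:
    if event.message_id is not None:
        chunks_by_message[event.message_id].append(event.chunk)


@crewai_event_bus.on(LLMCallCompletedEvent)
def finish_call(source, event: LLMCallCompletedEvent) -> None:
    if event.message_id is not None:
        text = "".join(chunks_by_message.pop(event.message_id, []))
        print(f"call {event.message_id}: streamed {len(text)} characters")
```

Because `call()` generates the id itself with `uuid.uuid4().hex`, callers pass nothing new, and existing handlers that ignore the field keep working unchanged.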