diff --git a/lib/crewai/src/crewai/events/types/llm_events.py b/lib/crewai/src/crewai/events/types/llm_events.py
index c6db9405db..a5fc6b637a 100644
--- a/lib/crewai/src/crewai/events/types/llm_events.py
+++ b/lib/crewai/src/crewai/events/types/llm_events.py
@@ -9,6 +9,7 @@ class LLMEventBase(BaseEvent):
 
     from_task: Any | None = None
     from_agent: Any | None = None
+    message_id: str | None = None
 
     def __init__(self, **data):
         if data.get("from_task"):
diff --git a/lib/crewai/src/crewai/llm.py b/lib/crewai/src/crewai/llm.py
index 2e2684ebe1..25e0084ab7 100644
--- a/lib/crewai/src/crewai/llm.py
+++ b/lib/crewai/src/crewai/llm.py
@@ -18,6 +18,7 @@
     TypedDict,
     cast,
 )
+import uuid
 
 from dotenv import load_dotenv
 import httpx
@@ -532,6 +533,7 @@ def _handle_streaming_response(
         from_task: Task | None = None,
         from_agent: Agent | None = None,
         response_model: type[BaseModel] | None = None,
+        message_id: str | None = None,
     ) -> Any:
         """Handle a streaming response from the LLM.
 
@@ -626,6 +628,7 @@ def _handle_streaming_response(
                         available_functions=available_functions,
                         from_task=from_task,
                         from_agent=from_agent,
+                        message_id=message_id,
                     )
 
                     if result is not None:
@@ -646,6 +649,7 @@ def _handle_streaming_response(
                             chunk=chunk_content,
                             from_task=from_task,
                             from_agent=from_agent,
+                            message_id=message_id,
                         ),
                     )
         # --- 4) Fallback to non-streaming if no content received
@@ -763,6 +767,7 @@ def _handle_streaming_response(
                     from_task=from_task,
                     from_agent=from_agent,
                     messages=params["messages"],
+                    message_id=message_id,
                 )
                 return structured_response
 
@@ -772,11 +777,12 @@ def _handle_streaming_response(
                 from_task=from_task,
                 from_agent=from_agent,
                 messages=params["messages"],
+                message_id=message_id,
             )
             return full_response
 
         # --- 9) Handle tool calls if present
-        tool_result = self._handle_tool_call(tool_calls, available_functions)
+        tool_result = self._handle_tool_call(tool_calls, available_functions, from_task, from_agent, message_id)
 
         if tool_result is not None:
             return tool_result
@@ -792,6 +798,7 @@ def _handle_streaming_response(
             from_task=from_task,
             from_agent=from_agent,
             messages=params["messages"],
+            message_id=message_id,
         )
         return full_response
 
@@ -810,13 +817,14 @@ def _handle_streaming_response(
                     from_task=from_task,
                     from_agent=from_agent,
                     messages=params["messages"],
+                    message_id=message_id,
                 )
                 return full_response
 
             crewai_event_bus.emit(
                 self,
                 event=LLMCallFailedEvent(
-                    error=str(e), from_task=from_task, from_agent=from_agent
+                    error=str(e), from_task=from_task, from_agent=from_agent, message_id=message_id
                 ),
             )
             raise Exception(f"Failed to get streaming response: {e!s}") from e
@@ -828,6 +836,7 @@ def _handle_streaming_tool_calls(
         available_functions: dict[str, Any] | None = None,
         from_task: Task | None = None,
         from_agent: Agent | None = None,
+        message_id: str | None = None,
     ) -> Any:
         for tool_call in tool_calls:
             current_tool_accumulator = accumulated_tool_args[tool_call.index]
@@ -847,6 +856,7 @@ def _handle_streaming_tool_calls(
                         chunk=tool_call.function.arguments,
                         from_task=from_task,
                         from_agent=from_agent,
+                        message_id=message_id,
                     ),
                 )
 
@@ -861,6 +871,9 @@ def _handle_streaming_tool_calls(
                     return self._handle_tool_call(
                         [current_tool_accumulator],
                         available_functions,
+                        from_task,
+                        from_agent,
+                        message_id,
                     )
                 except json.JSONDecodeError:
                     continue
@@ -914,6 +927,7 @@ def _handle_non_streaming_response(
         from_task: Task | None = None,
         from_agent: Agent | None = None,
         response_model: type[BaseModel] | None = None,
+        message_id: str | None = None,
     ) -> str | Any:
         """Handle a non-streaming response from the LLM.
 
@@ -954,6 +968,7 @@ def _handle_non_streaming_response(
                 from_task=from_task,
                 from_agent=from_agent,
                 messages=params["messages"],
+                message_id=message_id,
             )
             return structured_response
 
@@ -982,6 +997,7 @@ def _handle_non_streaming_response(
                     from_task=from_task,
                     from_agent=from_agent,
                     messages=params["messages"],
+                    message_id=message_id,
                 )
                 return structured_response
 
@@ -1013,6 +1029,7 @@ def _handle_non_streaming_response(
                 from_task=from_task,
                 from_agent=from_agent,
                 messages=params["messages"],
+                message_id=message_id,
             )
             return text_response
 
@@ -1022,7 +1039,7 @@ def _handle_non_streaming_response(
 
         # --- 7) Handle tool calls if present
         tool_result = self._handle_tool_call(
-            tool_calls, available_functions, from_task, from_agent
+            tool_calls, available_functions, from_task, from_agent, message_id
         )
         if tool_result is not None:
             return tool_result
@@ -1033,6 +1050,7 @@ def _handle_non_streaming_response(
             from_task=from_task,
             from_agent=from_agent,
             messages=params["messages"],
+            message_id=message_id,
        )
         return text_response
 
@@ -1042,6 +1060,7 @@ def _handle_tool_call(
         available_functions: dict[str, Any] | None = None,
         from_task: Task | None = None,
         from_agent: Agent | None = None,
+        message_id: str | None = None,
     ) -> Any:
         """Handle a tool call from the LLM.
 
@@ -1101,6 +1120,7 @@ def _handle_tool_call(
                     call_type=LLMCallType.TOOL_CALL,
                     from_task=from_task,
                     from_agent=from_agent,
+                    message_id=message_id,
                 )
                 return result
             except Exception as e:
@@ -1111,7 +1131,7 @@ def _handle_tool_call(
                 logging.error(f"Error executing function '{function_name}': {e}")
                 crewai_event_bus.emit(
                     self,
-                    event=LLMCallFailedEvent(error=f"Tool execution error: {e!s}"),
+                    event=LLMCallFailedEvent(error=f"Tool execution error: {e!s}", message_id=message_id),
                 )
                 crewai_event_bus.emit(
                     self,
@@ -1161,6 +1181,8 @@ def call(
             ValueError: If response format is not supported
             LLMContextLengthExceededError: If input exceeds model's context limit
         """
+        message_id = uuid.uuid4().hex
+
         crewai_event_bus.emit(
             self,
             event=LLMCallStartedEvent(
@@ -1171,6 +1193,7 @@ def call(
                 from_task=from_task,
                 from_agent=from_agent,
                 model=self.model,
+                message_id=message_id,
             ),
         )
 
@@ -1202,6 +1225,7 @@ def call(
                     from_task=from_task,
                     from_agent=from_agent,
                     response_model=response_model,
+                    message_id=message_id,
                 )
 
             return self._handle_non_streaming_response(
@@ -1211,6 +1235,7 @@ def call(
                 from_task=from_task,
                 from_agent=from_agent,
                 response_model=response_model,
+                message_id=message_id,
             )
         except LLMContextLengthExceededError:
             # Re-raise LLMContextLengthExceededError as it should be handled
@@ -1248,7 +1273,7 @@ def call(
             crewai_event_bus.emit(
                 self,
                 event=LLMCallFailedEvent(
-                    error=str(e), from_task=from_task, from_agent=from_agent
+                    error=str(e), from_task=from_task, from_agent=from_agent, message_id=message_id
                 ),
             )
             raise
@@ -1260,6 +1285,7 @@ def _handle_emit_call_events(
         from_task: Task | None = None,
         from_agent: Agent | None = None,
         messages: str | list[LLMMessage] | None = None,
+        message_id: str | None = None,
     ) -> None:
         """Handle the events for the LLM call.
 
@@ -1269,6 +1295,7 @@ def _handle_emit_call_events(
             from_task: Optional task object
             from_agent: Optional agent object
             messages: Optional messages object
+            message_id: Optional message identifier
         """
         crewai_event_bus.emit(
             self,
@@ -1279,6 +1306,7 @@ def _handle_emit_call_events(
                 from_task=from_task,
                 from_agent=from_agent,
                 model=self.model,
+                message_id=message_id,
             ),
         )
 
diff --git a/lib/crewai/tests/test_llm_message_id.py b/lib/crewai/tests/test_llm_message_id.py
new file mode 100644
index 0000000000..6b63647d72
--- /dev/null
+++ b/lib/crewai/tests/test_llm_message_id.py
@@ -0,0 +1,178 @@
+import threading
+from unittest.mock import Mock, patch
+
+import pytest
+
+from crewai.agent import Agent
+from crewai.events.event_bus import crewai_event_bus
+from crewai.events.types.llm_events import (
+    LLMCallCompletedEvent,
+    LLMCallStartedEvent,
+    LLMStreamChunkEvent,
+)
+from crewai.llm import LLM
+from crewai.task import Task
+
+
+@pytest.fixture
+def base_agent():
+    return Agent(
+        role="test_agent",
+        llm="gpt-4o-mini",
+        goal="Test message_id",
+        backstory="You are a test assistant",
+    )
+
+
+@pytest.fixture
+def base_task(base_agent):
+    return Task(
+        description="Test message_id",
+        expected_output="test",
+        agent=base_agent,
+    )
+
+
+def test_llm_events_have_unique_message_ids_for_different_calls(base_agent, base_task):
+    """Test that different LLM calls have different message_ids"""
+    received_events = []
+    event_received = threading.Event()
+
+    @crewai_event_bus.on(LLMCallStartedEvent)
+    def handle_llm_started(source, event):
+        received_events.append(event)
+        if len(received_events) >= 2:
+            event_received.set()
+
+    llm = LLM(model="gpt-4o-mini")
+
+    with patch("litellm.completion") as mock_completion:
+        mock_completion.return_value = Mock(
+            choices=[Mock(message=Mock(content="Response 1", tool_calls=None))],
+            usage=Mock(prompt_tokens=10, completion_tokens=20, total_tokens=30),
+        )
+
+        llm.call("Test message 1", from_task=base_task, from_agent=base_agent)
+        llm.call("Test message 2", from_task=base_task, from_agent=base_agent)
+
+    assert event_received.wait(timeout=5), "Timeout waiting for LLM started events"
+    assert len(received_events) >= 2
+    assert received_events[0].message_id is not None
+    assert received_events[1].message_id is not None
+    assert received_events[0].message_id != received_events[1].message_id
+
+
+def test_streaming_chunks_have_same_message_id(base_agent, base_task):
+    """Test that all chunks from the same streaming call have the same message_id"""
+    received_events = []
+    lock = threading.Lock()
+    all_events_received = threading.Event()
+
+    @crewai_event_bus.on(LLMStreamChunkEvent)
+    def handle_stream_chunk(source, event):
+        with lock:
+            received_events.append(event)
+            if len(received_events) >= 3:
+                all_events_received.set()
+
+    llm = LLM(model="gpt-4o-mini", stream=True)
+
+    def mock_stream_generator():
+        yield Mock(
+            choices=[Mock(delta=Mock(content="Hello", tool_calls=None))],
+            usage=None,
+        )
+        yield Mock(
+            choices=[Mock(delta=Mock(content=" ", tool_calls=None))],
+            usage=None,
+        )
+        yield Mock(
+            choices=[Mock(delta=Mock(content="World", tool_calls=None))],
+            usage=Mock(prompt_tokens=10, completion_tokens=20, total_tokens=30),
+        )
+
+    with patch("litellm.completion", return_value=mock_stream_generator()):
+        llm.call("Test streaming", from_task=base_task, from_agent=base_agent)
+
+    assert all_events_received.wait(timeout=5), "Timeout waiting for stream chunk events"
+    assert len(received_events) >= 3
+
+    message_ids = [event.message_id for event in received_events]
+    assert all(mid is not None for mid in message_ids)
+    assert len(set(message_ids)) == 1, "All chunks should have the same message_id"
+
+
+def test_completed_event_has_same_message_id_as_started(base_agent, base_task):
+    """Test that Started and Completed events have the same message_id"""
+    received_events = {"started": None, "completed": None}
+    lock = threading.Lock()
+    all_events_received = threading.Event()
+
+    @crewai_event_bus.on(LLMCallStartedEvent)
+    def handle_started(source, event):
+        with lock:
+            received_events["started"] = event
+            if received_events["completed"] is not None:
+                all_events_received.set()
+
+    @crewai_event_bus.on(LLMCallCompletedEvent)
+    def handle_completed(source, event):
+        with lock:
+            received_events["completed"] = event
+            if received_events["started"] is not None:
+                all_events_received.set()
+
+    llm = LLM(model="gpt-4o-mini")
+
+    with patch("litellm.completion") as mock_completion:
+        mock_completion.return_value = Mock(
+            choices=[Mock(message=Mock(content="Response", tool_calls=None))],
+            usage=Mock(prompt_tokens=10, completion_tokens=20, total_tokens=30),
+        )
+
+        llm.call("Test message", from_task=base_task, from_agent=base_agent)
+
+    assert all_events_received.wait(timeout=5), "Timeout waiting for events"
+    assert received_events["started"] is not None
+    assert received_events["completed"] is not None
+    assert received_events["started"].message_id is not None
+    assert received_events["completed"].message_id is not None
+    assert received_events["started"].message_id == received_events["completed"].message_id
+
+
+def test_multiple_calls_same_agent_task_have_different_message_ids(base_agent, base_task):
+    """Test that multiple calls from the same agent/task have different message_ids"""
+    received_started_events = []
+    lock = threading.Lock()
+    all_events_received = threading.Event()
+
+    @crewai_event_bus.on(LLMCallStartedEvent)
+    def handle_started(source, event):
+        with lock:
+            received_started_events.append(event)
+            if len(received_started_events) >= 3:
+                all_events_received.set()
+
+    llm = LLM(model="gpt-4o-mini")
+
+    with patch("litellm.completion") as mock_completion:
+        mock_completion.return_value = Mock(
+            choices=[Mock(message=Mock(content="Response", tool_calls=None))],
+            usage=Mock(prompt_tokens=10, completion_tokens=20, total_tokens=30),
+        )
+
+        llm.call("Message 1", from_task=base_task, from_agent=base_agent)
+        llm.call("Message 2", from_task=base_task, from_agent=base_agent)
+        llm.call("Message 3", from_task=base_task, from_agent=base_agent)
+
+    assert all_events_received.wait(timeout=5), "Timeout waiting for events"
+    assert len(received_started_events) >= 3
+
+    message_ids = [event.message_id for event in received_started_events]
+    assert all(mid is not None for mid in message_ids)
+    assert len(set(message_ids)) == 3, "Each call should have a unique message_id"
+
+    task_ids = [event.task_id for event in received_started_events]
+    agent_ids = [event.agent_id for event in received_started_events]
+    assert len(set(task_ids)) == 1, "All calls should have the same task_id"
+    assert len(set(agent_ids)) == 1, "All calls should have the same agent_id"
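
Not part of the diff: a minimal, hedged sketch of how a downstream subscriber could use the new message_id field to correlate the started, chunk, and completed events of a single LLM call. The listener registration mirrors the pattern used in the tests above; the handler names and the `calls` dictionary are illustrative only, not part of the CrewAI API.

from collections import defaultdict

from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.llm_events import (
    LLMCallCompletedEvent,
    LLMCallStartedEvent,
    LLMStreamChunkEvent,
)

# Hypothetical accumulator: message_id -> ordered list of events for that call.
calls = defaultdict(list)


@crewai_event_bus.on(LLMCallStartedEvent)
def on_started(source, event):
    calls[event.message_id].append(event)


@crewai_event_bus.on(LLMStreamChunkEvent)
def on_chunk(source, event):
    calls[event.message_id].append(event)


@crewai_event_bus.on(LLMCallCompletedEvent)
def on_completed(source, event):
    calls[event.message_id].append(event)
    # All chunks emitted for this call can now be reassembled in arrival order,
    # even when several calls from the same agent/task interleave.
    streamed_text = "".join(
        e.chunk for e in calls[event.message_id] if isinstance(e, LLMStreamChunkEvent)
    )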