Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 72 additions & 5 deletions lib/crewai/src/crewai/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -928,7 +928,17 @@ def _handle_streaming_response(
if not tool_calls or not available_functions:
# Track token usage and log callbacks if available in streaming mode
if usage_info:
self._track_token_usage_internal(usage_info)
# Convert usage object to dict if needed
if hasattr(usage_info, "__dict__"):
usage_dict = {
"prompt_tokens": getattr(usage_info, "prompt_tokens", 0),
"completion_tokens": getattr(usage_info, "completion_tokens", 0),
"total_tokens": getattr(usage_info, "total_tokens", 0),
"cached_tokens": getattr(usage_info, "cached_tokens", 0),
}
else:
usage_dict = usage_info
self._track_token_usage_internal(usage_dict)
self._handle_streaming_callbacks(callbacks, usage_info, last_chunk)

if response_model and self.is_litellm:
Expand Down Expand Up @@ -964,7 +974,17 @@ def _handle_streaming_response(

# --- 10) Track token usage and log callbacks if available in streaming mode
if usage_info:
self._track_token_usage_internal(usage_info)
# Convert usage object to dict if needed
if hasattr(usage_info, "__dict__"):
usage_dict = {
"prompt_tokens": getattr(usage_info, "prompt_tokens", 0),
"completion_tokens": getattr(usage_info, "completion_tokens", 0),
"total_tokens": getattr(usage_info, "total_tokens", 0),
"cached_tokens": getattr(usage_info, "cached_tokens", 0),
}
else:
usage_dict = usage_info
self._track_token_usage_internal(usage_dict)
self._handle_streaming_callbacks(callbacks, usage_info, last_chunk)

# --- 11) Emit completion event and return response
Expand Down Expand Up @@ -1173,7 +1193,23 @@ def _handle_non_streaming_response(
0
].message
text_response = response_message.content or ""
# --- 3) Handle callbacks with usage info

# --- 3a) Track token usage internally
usage_info = getattr(response, "usage", None)
if usage_info:
# Convert usage object to dict if needed
if hasattr(usage_info, "__dict__"):
usage_dict = {
"prompt_tokens": getattr(usage_info, "prompt_tokens", 0),
"completion_tokens": getattr(usage_info, "completion_tokens", 0),
"total_tokens": getattr(usage_info, "total_tokens", 0),
"cached_tokens": getattr(usage_info, "cached_tokens", 0),
}
else:
usage_dict = usage_info
self._track_token_usage_internal(usage_dict)

# --- 3b) Handle callbacks with usage info
if callbacks and len(callbacks) > 0:
for callback in callbacks:
if hasattr(callback, "log_success_event"):
Expand Down Expand Up @@ -1293,10 +1329,24 @@ async def _ahandle_non_streaming_response(
].message
text_response = response_message.content or ""

# Track token usage internally
usage_info = getattr(response, "usage", None)
if usage_info:
# Convert usage object to dict if needed
if hasattr(usage_info, "__dict__"):
usage_dict = {
"prompt_tokens": getattr(usage_info, "prompt_tokens", 0),
"completion_tokens": getattr(usage_info, "completion_tokens", 0),
"total_tokens": getattr(usage_info, "total_tokens", 0),
"cached_tokens": getattr(usage_info, "cached_tokens", 0),
}
else:
usage_dict = usage_info
self._track_token_usage_internal(usage_dict)

if callbacks and len(callbacks) > 0:
for callback in callbacks:
if hasattr(callback, "log_success_event"):
usage_info = getattr(response, "usage", None)
if usage_info:
callback.log_success_event(
kwargs=params,
Expand Down Expand Up @@ -1381,7 +1431,10 @@ async def _ahandle_streaming_response(
if not isinstance(chunk.choices, type):
choices = chunk.choices

if hasattr(chunk, "usage") and chunk.usage is not None:
# Try to extract usage information if available
if isinstance(chunk, dict) and "usage" in chunk:
usage_info = chunk["usage"]
elif hasattr(chunk, "usage") and chunk.usage is not None:
usage_info = chunk.usage

if choices and len(choices) > 0:
Expand Down Expand Up @@ -1434,6 +1487,20 @@ async def _ahandle_streaming_response(
),
)

# Track token usage internally
if usage_info:
# Convert usage object to dict if needed
if hasattr(usage_info, "__dict__"):
usage_dict = {
"prompt_tokens": getattr(usage_info, "prompt_tokens", 0),
"completion_tokens": getattr(usage_info, "completion_tokens", 0),
"total_tokens": getattr(usage_info, "total_tokens", 0),
"cached_tokens": getattr(usage_info, "cached_tokens", 0),
}
else:
usage_dict = usage_info
self._track_token_usage_internal(usage_dict)

if callbacks and len(callbacks) > 0 and usage_info:
for callback in callbacks:
if hasattr(callback, "log_success_event"):
Expand Down
Loading
Loading