Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/google/adk/features/_feature_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class FeatureConfig:
FeatureStage.WIP, default_on=False
),
FeatureName.PROGRESSIVE_SSE_STREAMING: FeatureConfig(
FeatureStage.WIP, default_on=False
FeatureStage.EXPERIMENTAL, default_on=True
),
FeatureName.PUBSUB_TOOLSET: FeatureConfig(
FeatureStage.EXPERIMENTAL, default_on=True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,6 @@
from google.adk.runners import InMemoryRunner
from google.adk.utils.streaming_utils import StreamingResponseAggregator
from google.genai import types
import pytest


@pytest.fixture(autouse=True)
def reset_env(monkeypatch):
monkeypatch.setenv("ADK_ENABLE_PROGRESSIVE_SSE_STREAMING", "1")
yield
monkeypatch.delenv("ADK_ENABLE_PROGRESSIVE_SSE_STREAMING")


def get_weather(location: str) -> dict[str, Any]:
Expand Down
140 changes: 60 additions & 80 deletions tests/unittests/models/test_google_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1116,10 +1116,16 @@ async def mock_coro():
)
]

# Should have only 1 response (no aggregated content generated)
assert len(responses) == 1
# Verify it's a function call, not text
# With progressive SSE streaming enabled by default, we get 2 responses:
# 1. Partial response with function call
# 2. Final aggregated response with function call
assert len(responses) == 2
# First response is partial
assert responses[0].partial is True
assert responses[0].content.parts[0].function_call is not None
# Second response is the final aggregated response
assert responses[1].partial is False
assert responses[1].content.parts[0].function_call is not None


@pytest.mark.asyncio
Expand Down Expand Up @@ -1194,37 +1200,33 @@ async def mock_coro():
)
]

# Should have multiple responses:
# With progressive SSE streaming enabled, we get 4 responses:
# 1. Partial text "First text"
# 2. Aggregated "First text" when function call interrupts
# 3. Function call
# 4. Partial text " second text"
# 5. Final aggregated " second text"
assert len(responses) == 5
# 2. Partial function call
# 3. Partial text " second text"
# 4. Final aggregated response with all parts (text + FC + text)
assert len(responses) == 4

# First partial text
assert responses[0].partial is True
assert responses[0].content.parts[0].text == "First text"

# Aggregated first text (when function call interrupts)
assert responses[1].content.parts[0].text == "First text"
assert (
responses[1].partial is None
) # Aggregated responses don't have partial flag

# Function call
assert responses[2].content.parts[0].function_call is not None
assert responses[2].content.parts[0].function_call.name == "test_func"
# Partial function call
assert responses[1].partial is True
assert responses[1].content.parts[0].function_call is not None
assert responses[1].content.parts[0].function_call.name == "test_func"

# Second partial text
assert responses[3].partial is True
assert responses[3].content.parts[0].text == " second text"
# Partial second text
assert responses[2].partial is True
assert responses[2].content.parts[0].text == " second text"

# Final aggregated text with error info
assert responses[4].content.parts[0].text == " second text"
assert (
responses[4].error_code is None
) # STOP finish reason should have None error_code
# Final aggregated response with all parts
assert responses[3].partial is False
assert len(responses[3].content.parts) == 3
assert responses[3].content.parts[0].text == "First text"
assert responses[3].content.parts[1].function_call.name == "test_func"
assert responses[3].content.parts[2].text == " second text"
assert responses[3].error_code is None # STOP finish reason


@pytest.mark.asyncio
Expand Down Expand Up @@ -1376,28 +1378,27 @@ async def mock_coro():
)
]

# Should properly separate thought and regular text across aggregations
assert len(responses) > 5 # Multiple partial + aggregated responses
# With progressive SSE streaming, we get 6 responses:
# 5 partial responses + 1 final aggregated response
assert len(responses) == 6

# Verify we get both thought and regular text parts in aggregated responses
aggregated_responses = [
r
for r in responses
if r.partial is None and r.content and len(r.content.parts) > 1
]
assert (
len(aggregated_responses) > 0
) # Should have at least one aggregated response with multiple parts
# All but the last should be partial
for i in range(5):
assert responses[i].partial is True

# Final aggregated response should have both thought and text
# Final aggregated response should have all parts
final_response = responses[-1]
assert (
final_response.error_code is None
) # STOP finish reason should have None error_code
assert len(final_response.content.parts) == 2 # thought part + text part
assert final_response.partial is False
assert final_response.error_code is None # STOP finish reason
# Final response aggregates: thought + text + FC + thought + text
assert len(final_response.content.parts) == 5
assert final_response.content.parts[0].thought is True
assert "More thinking..." in final_response.content.parts[0].text
assert final_response.content.parts[1].text == " and conclusion"
assert "Thinking..." in final_response.content.parts[0].text
assert final_response.content.parts[1].text == "Here's my answer"
assert final_response.content.parts[2].function_call.name == "lookup"
assert final_response.content.parts[3].thought is True
assert "More thinking..." in final_response.content.parts[3].text
assert final_response.content.parts[4].text == " and conclusion"


@pytest.mark.asyncio
Expand Down Expand Up @@ -1491,44 +1492,23 @@ async def mock_coro():
)
]

# Find the aggregated text responses (non-partial, text-only)
aggregated_text_responses = [
r
for r in responses
if (
r.partial is None
and r.content
and r.content.parts
and r.content.parts[0].text
and not r.content.parts[0].function_call
)
]

# Should have two separate text aggregations: "First chunk" and "Second chunk"
assert len(aggregated_text_responses) >= 2
# With progressive SSE streaming, we get 6 responses:
# 5 partial responses + 1 final aggregated response
assert len(responses) == 6

# First aggregation should contain "First chunk"
first_aggregation = aggregated_text_responses[0]
assert first_aggregation.content.parts[0].text == "First chunk"
# All but the last should be partial
for i in range(5):
assert responses[i].partial is True

# Final aggregation should contain "Second chunk" and have error info
final_aggregation = aggregated_text_responses[-1]
assert final_aggregation.content.parts[0].text == "Second chunk"
assert (
final_aggregation.error_code is None
) # STOP finish reason should have None error_code

# Verify the function call is preserved between aggregations
function_call_responses = [
r
for r in responses
if (r.content and r.content.parts and r.content.parts[0].function_call)
]
assert len(function_call_responses) == 1
assert (
function_call_responses[0].content.parts[0].function_call.name
== "divide"
)
# Final response should be aggregated with all parts
final_response = responses[-1]
assert final_response.partial is False
assert final_response.error_code is None # STOP finish reason
# Final response aggregates: text1 + text2 + FC + text3 + text4
assert len(final_response.content.parts) == 3
assert final_response.content.parts[0].text == "First chunk"
assert final_response.content.parts[1].function_call.name == "divide"
assert final_response.content.parts[2].text == "Second chunk"


@pytest.mark.asyncio
Expand Down
Loading