Fix structured outputs on various models #2109

Open · wants to merge 18 commits into base: main

Changes from 12 commits
8 changes: 4 additions & 4 deletions cookbook/models/azure/openai/structured_output.py

@@ -29,11 +29,11 @@ class MovieScript(BaseModel):
     model=AzureOpenAI(id="gpt-4o-mini"),
     description="You help people write movie scripts.",
     response_model=MovieScript,
-    # debug_mode=True,
+    structured_outputs=True,
 )

 # Get the response in a variable
-# run: RunResponse = agent.run("New York")
-# pprint(run.content)
+run: RunResponse = agent.run("New York")
+pprint(run.content)

-agent.print_response("New York")
+# agent.print_response("New York")
6 changes: 2 additions & 4 deletions cookbook/models/cohere/structured_output.py

@@ -33,7 +33,5 @@ class MovieScript(BaseModel):
 )

 # Get the response in a variable
-# json_mode_response: RunResponse = json_mode_agent.run("New York")
-# pprint(json_mode_response.content)
-
-json_mode_agent.print_response("New York")
+json_mode_response: RunResponse = json_mode_agent.run("New York")
+pprint(json_mode_response.content)
6 changes: 3 additions & 3 deletions cookbook/models/deepseek/structured_output.py

@@ -32,7 +32,7 @@ class MovieScript(BaseModel):
 )

 # Get the response in a variable
-# json_mode_response: RunResponse = json_mode_agent.run("New York")
-# pprint(json_mode_response.content)
+json_mode_response: RunResponse = json_mode_agent.run("New York")
+pprint(json_mode_response.content)

-json_mode_agent.print_response("New York")
+# json_mode_agent.print_response("New York")
6 changes: 3 additions & 3 deletions cookbook/models/fireworks/structured_output.py

@@ -33,7 +33,7 @@ class MovieScript(BaseModel):
 )

 # Get the response in a variable
-# response: RunResponse = agent.run("New York")
-# pprint(json_mode_response.content)
+response: RunResponse = agent.run("New York")
+pprint(response.content)

-agent.print_response("New York")
+# agent.print_response("New York")
6 changes: 3 additions & 3 deletions cookbook/models/groq/structured_output.py

@@ -33,7 +33,7 @@ class MovieScript(BaseModel):
 )

 # Get the response in a variable
-# run: RunResponse = json_mode_agent.run("New York")
-# pprint(run.content)
+run: RunResponse = json_mode_agent.run("New York")
+pprint(run.content)

-json_mode_agent.print_response("New York")
+# json_mode_agent.print_response("New York")
11 changes: 6 additions & 5 deletions cookbook/models/mistral/structured_output.py

@@ -27,19 +27,20 @@ class MovieScript(BaseModel):
 )


-json_mode_agent = Agent(
+structured_output_agent = Agent(
     model=MistralChat(
         id="mistral-large-latest",
     ),
     tools=[DuckDuckGoTools()],
     description="You help people write movie scripts.",
     response_model=MovieScript,
+    structured_outputs=True,
     show_tool_calls=True,
-    debug_mode=True,
+    # debug_mode=True,
 )

 # Get the response in a variable
-# json_mode_response: RunResponse = json_mode_agent.run("New York")
-# pprint(json_mode_response.content)
+structured_output_response: RunResponse = structured_output_agent.run("New York")
+pprint(structured_output_response.content)

-json_mode_agent.print_response("Find a cool movie idea about London and write it.")
+# json_mode_agent.print_response("Find a cool movie idea about London and write it.")
13 changes: 7 additions & 6 deletions cookbook/models/ollama/structured_output.py

@@ -3,7 +3,9 @@

 from agno.agent import Agent
 from agno.models.ollama import Ollama
+from agno.run.response import RunResponse
 from pydantic import BaseModel, Field
+from rich.pretty import pprint  # noqa


 class MovieScript(BaseModel):

@@ -34,12 +36,11 @@ class MovieScript(BaseModel):
 )

 # Run the agent synchronously
-structured_output_agent.print_response("Llamas ruling the world")
+structured_output_response: RunResponse = structured_output_agent.run(
+    "Llamas ruling the world"
+)
+pprint(structured_output_response.content)


 # Run the agent asynchronously
-async def run_agents_async():
-    await structured_output_agent.aprint_response("Llamas ruling the world")
-
-
-asyncio.run(run_agents_async())
+asyncio.run(structured_output_agent.aprint_response("Llamas ruling the world"))
42 changes: 42 additions & 0 deletions cookbook/models/xai/structured_output.py

@@ -0,0 +1,42 @@
+import asyncio
+from typing import List
+
+from agno.agent import Agent
+from agno.models.xai.xai import xAI
+from agno.run.response import RunResponse
+from pydantic import BaseModel, Field
+from rich.pretty import pprint  # noqa
+
+
+class MovieScript(BaseModel):
+    name: str = Field(..., description="Give a name to this movie")
+    setting: str = Field(
+        ..., description="Provide a nice setting for a blockbuster movie."
+    )
+    ending: str = Field(
+        ...,
+        description="Ending of the movie. If not available, provide a happy ending.",
+    )
+    genre: str = Field(
+        ...,
+        description="Genre of the movie. If not available, select action, thriller or romantic comedy.",
+    )
+    characters: List[str] = Field(..., description="Name of characters for this movie.")
+    storyline: str = Field(
+        ..., description="3 sentence storyline for the movie. Make it exciting!"
+    )
+
+
+# Agent that returns a structured output
+structured_output_agent = Agent(
+    model=xAI(id="grok-2-latest"),
+    description="You write movie scripts.",
+    response_model=MovieScript,
+    structured_outputs=True,
+)
+
+# Run the agent synchronously
+structured_output_response: RunResponse = structured_output_agent.run(
+    "Llamas ruling the world"
+)
+pprint(structured_output_response.content)
1 change: 1 addition & 0 deletions libs/agno/agno/agent/agent.py

@@ -2457,6 +2457,7 @@ def get_relevant_docs_from_knowledge(
         if self.knowledge is None:
             return None

+        # TODO: add async support
         relevant_docs: List[Document] = self.knowledge.search(query=query, num_documents=num_documents, **kwargs)
         if len(relevant_docs) == 0:
             return None
4 changes: 3 additions & 1 deletion libs/agno/agno/models/azure/ai_foundry.py

@@ -27,7 +27,9 @@
     from azure.core.credentials import AzureKeyCredential
     from azure.core.exceptions import HttpResponseError
 except ImportError:
-    logger.error("`azure-ai-inference` not installed. Please install it via `pip install azure-ai-inference aiohttp`.")
+    raise ImportError(
+        "`azure-ai-inference` not installed. Please install it via `pip install azure-ai-inference aiohttp`."
+    )


 @dataclass
2 changes: 2 additions & 0 deletions libs/agno/agno/models/azure/openai_chat.py

@@ -38,6 +38,8 @@ class AzureOpenAI(OpenAILike):
     name: str = "AzureOpenAI"
     provider: str = "Azure"

+    supports_structured_outputs: bool = True
Review comment on `supports_structured_outputs`:

Contributor: Looks like we are just defining this field and not using it anywhere.

Author: This logic in the agent is the important part:

    # Update the response_format on the Model
    if self.response_model is not None:
        # This will pass the pydantic model to the model
        if self.structured_outputs and self.model.supports_structured_outputs:
            logger.debug("Setting Model.response_format to Agent.response_model")
            self.model.response_format = self.response_model
            self.model.structured_outputs = True
        else:
            # Otherwise we just want JSON
            self.model.response_format = {"type": "json_object"}
    else:
        self.model.response_format = None

If we don't set the flag in the model, then it will default to JSON mode, which might work, but the native support should be better. The naming we use is, admittedly, confusing.
+
     api_key: Optional[str] = None
     api_version: Optional[str] = "2024-10-21"
     azure_endpoint: Optional[str] = None
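For context (not part of the diff), here is a minimal usage sketch of the behavior the author describes; the `AzureOpenAI` import path and the trimmed response model are assumptions for illustration:

    from agno.agent import Agent
    from agno.models.azure import AzureOpenAI  # import path assumed
    from agno.run.response import RunResponse
    from pydantic import BaseModel


    class MovieScript(BaseModel):  # trimmed-down stand-in for the cookbook model
        name: str
        storyline: str


    agent = Agent(
        model=AzureOpenAI(id="gpt-4o-mini"),
        response_model=MovieScript,  # the Pydantic class stays on the Agent
        structured_outputs=True,  # with supports_structured_outputs=True on the
    )  # model, the class itself is passed through instead of generic JSON mode

    run: RunResponse = agent.run("New York")
    print(run.content)  # a MovieScript instance, not a raw JSON string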
3 changes: 3 additions & 0 deletions libs/agno/agno/models/cohere/chat.py

@@ -100,6 +100,9 @@ def request_kwargs(self) -> Dict[str, Any]:
         if self.presence_penalty:
             _request_params["presence_penalty"] = self.presence_penalty

+        if self.response_format:
Contributor: The response_model should be set at the Agent level. We should make sure it is consistent for each provider.

Author: What goes to the model is only response_format and whether to use structured_outputs or not. We don't send the response_model field; that is only used in the agent.
+            _request_params["response_format"] = self.response_format
+
         if self._tools is not None and len(self._tools) > 0:
             _request_params["tools"] = self._tools
         if self.tool_choice is not None:
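To make that division of responsibility concrete, here is a small illustrative sketch (simplified stand-in class, not the actual agno internals):

    from typing import Any, Dict, Optional


    class Model:  # simplified stand-in for illustration only
        response_format: Optional[Dict[str, Any]] = None

        def request_kwargs(self) -> Dict[str, Any]:
            params: Dict[str, Any] = {}
            if self.response_format:
                # Mirrors the Cohere change above: only response_format is forwarded.
                params["response_format"] = self.response_format
            return params


    model = Model()
    model.response_format = {"type": "json_object"}  # set by the Agent, not the user
    assert model.request_kwargs() == {"response_format": {"type": "json_object"}}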
50 changes: 29 additions & 21 deletions libs/agno/agno/models/groq/groq.py

@@ -21,23 +21,6 @@
     raise ImportError("`groq` not installed. Please install using `pip install groq`")


-def format_message(message: Message) -> Dict[str, Any]:
-    """
-    Format a message into the format expected by Groq.
-
-    Args:
-        message (Message): The message to format.
-
-    Returns:
-        Dict[str, Any]: The formatted message.
-    """
-    if message.role == "user":
-        if message.images is not None:
-            message = add_images_to_message(message=message, images=message.images)
-
-    return message.serialize_for_model()
-
-
 @dataclass
 class Groq(Model):
     """
@@ -213,6 +196,31 @@ def to_dict(self) -> Dict[str, Any]:
         cleaned_dict = {k: v for k, v in model_dict.items() if v is not None}
         return cleaned_dict

+    def format_message(self, message: Message) -> Dict[str, Any]:
+        """
+        Format a message into the format expected by Groq.
+
+        Args:
+            message (Message): The message to format.
+
+        Returns:
+            Dict[str, Any]: The formatted message.
+        """
+        if (
+            message.role == "system"
+            and isinstance(message.content, str)
+            and self.response_format is not None
+            and self.response_format.get("type") == "json_object"
+        ):
+            # This is required by Groq to ensure the model outputs in the correct format
+            message.content += "\n\nYour output should be in JSON format."
+
+        if message.role == "user":
+            if message.images is not None:
+                message = add_images_to_message(message=message, images=message.images)
+
+        return message.serialize_for_model()
+
     def invoke(self, messages: List[Message]) -> ChatCompletion:
         """
         Send a chat completion request to the Groq API.
@@ -226,7 +234,7 @@ def invoke(self, messages: List[Message]) -> ChatCompletion:
         try:
             return self.get_client().chat.completions.create(
                 model=self.id,
-                messages=[format_message(m) for m in messages],  # type: ignore
+                messages=[self.format_message(m) for m in messages],  # type: ignore
                 **self.request_kwargs,
             )
         except (APIError, APIConnectionError, APITimeoutError, APIStatusError) as e:
@@ -249,7 +257,7 @@ async def ainvoke(self, messages: List[Message]) -> ChatCompletion:
         try:
             return await self.get_async_client().chat.completions.create(
                 model=self.id,
-                messages=[format_message(m) for m in messages],  # type: ignore
+                messages=[self.format_message(m) for m in messages],  # type: ignore
                 **self.request_kwargs,
             )
         except (APIError, APIConnectionError, APITimeoutError, APIStatusError) as e:
@@ -272,7 +280,7 @@ def invoke_stream(self, messages: List[Message]) -> Iterator[ChatCompletionChunk]:
         try:
             return self.get_client().chat.completions.create(
                 model=self.id,
-                messages=[format_message(m) for m in messages],  # type: ignore
+                messages=[self.format_message(m) for m in messages],  # type: ignore
                 stream=True,
                 **self.request_kwargs,
             )
@@ -297,7 +305,7 @@ async def ainvoke_stream(self, messages: List[Message]) -> Any:
         try:
             stream = await self.get_async_client().chat.completions.create(
                 model=self.id,
-                messages=[format_message(m) for m in messages],  # type: ignore
+                messages=[self.format_message(m) for m in messages],  # type: ignore
                 stream=True,
                 **self.request_kwargs,
             )
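As an illustration of the new Groq behavior (a hypothetical snippet, not part of the diff; the import paths and model id are assumptions):

    from agno.models.groq import Groq  # import path assumed
    from agno.models.message import Message  # import path assumed

    groq = Groq(id="llama-3.3-70b-versatile")  # any Groq model id
    groq.response_format = {"type": "json_object"}  # what the Agent sets for JSON mode

    msg = Message(role="system", content="You help people write movie scripts.")
    groq.format_message(msg)  # appends the JSON instruction to the system message
    print(msg.content)  # ...now ends with "Your output should be in JSON format."

Without that suffix, Groq rejects or ignores `response_format={"type": "json_object"}`, which is why the formatting moved from a module-level function to a method that can read `self.response_format`.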
51 changes: 33 additions & 18 deletions libs/agno/agno/models/mistral/mistral.py

@@ -12,6 +12,7 @@
 try:
     from mistralai import CompletionEvent
     from mistralai import Mistral as MistralClient
+    from mistralai.extra.struct_chat import ParsedChatCompletionResponse
     from mistralai.models import (
         AssistantMessage,
         HTTPValidationError,
@@ -119,6 +120,8 @@ class MistralChat(Model):
     name: str = "MistralChat"
     provider: str = "Mistral"

+    supports_structured_outputs: bool = True
+
     # -*- Request parameters
     temperature: Optional[float] = None
     max_tokens: Optional[int] = None
@@ -231,7 +234,7 @@ def to_dict(self) -> Dict[str, Any]:
         cleaned_dict = {k: v for k, v in _dict.items() if v is not None}
         return cleaned_dict

-    def invoke(self, messages: List[Message]) -> ChatCompletionResponse:
+    def invoke(self, messages: List[Message]) -> Union[ChatCompletionResponse, ParsedChatCompletionResponse]:
         """
         Send a chat completion request to the Mistral model.

@@ -243,13 +246,20 @@ def invoke(self, messages: List[Message]) -> ChatCompletionResponse:
         """
         mistral_messages = _format_messages(messages)
         try:
-            response = self.get_client().chat.complete(
-                model=self.id,
-                messages=mistral_messages,
-                **self.request_kwargs,
-            )
-            if response is None:
-                raise ValueError("Chat completion returned None")
+            response: Union[ChatCompletionResponse, ParsedChatCompletionResponse]
+            if self.response_format is not None and self.structured_outputs:
+                response = self.get_client().chat.parse(
+                    model=self.id,
+                    messages=mistral_messages,
+                    response_format=self.response_format,  # type: ignore
+                    **self.request_kwargs,
+                )
+            else:
+                response = self.get_client().chat.complete(
+                    model=self.id,
+                    messages=mistral_messages,
+                    **self.request_kwargs,
+                )
             return response

         except HTTPValidationError as e:
@@ -276,8 +286,6 @@ def invoke_stream(self, messages: List[Message]) -> Iterator[Any]:
                 messages=mistral_messages,
                 **self.request_kwargs,
             )
-            if stream is None:
-                raise ValueError("Chat stream returned None")
             return stream
         except HTTPValidationError as e:
             logger.error(f"HTTPValidationError from Mistral: {e}")
@@ -286,7 +294,7 @@ def invoke_stream(self, messages: List[Message]) -> Iterator[Any]:
             logger.error(f"SDKError from Mistral: {e}")
             raise ModelProviderError(e, self.name, self.id) from e

-    async def ainvoke(self, messages: List[Message]) -> ChatCompletionResponse:
+    async def ainvoke(self, messages: List[Message]) -> Union[ChatCompletionResponse, ParsedChatCompletionResponse]:
         """
         Send an asynchronous chat completion request to the Mistral API.

@@ -298,13 +306,20 @@ async def ainvoke(self, messages: List[Message]) -> ChatCompletionResponse:
         """
         mistral_messages = _format_messages(messages)
         try:
-            response = await self.get_client().chat.complete_async(
-                model=self.id,
-                messages=mistral_messages,
-                **self.request_kwargs,
-            )
-            if response is None:
-                raise ValueError("Chat completion returned None")
+            response: Union[ChatCompletionResponse, ParsedChatCompletionResponse]
+            if self.response_format is not None and self.structured_outputs:
+                response = await self.get_client().chat.parse_async(
+                    model=self.id,
+                    messages=mistral_messages,
+                    response_format=self.response_format,  # type: ignore
+                    **self.request_kwargs,
+                )
+            else:
+                response = await self.get_client().chat.complete_async(
+                    model=self.id,
+                    messages=mistral_messages,
+                    **self.request_kwargs,
+                )
             return response
         except HTTPValidationError as e:
             logger.error(f"HTTPValidationError from Mistral: {e}")