From 3db091040ca527bfecc45f49e898d2025be070ce Mon Sep 17 00:00:00 2001
From: Dirk Brand
Date: Thu, 13 Feb 2025 16:33:26 +0200
Subject: [PATCH 1/6] Fix structured output on all models

---
 .../models/azure/openai/structured_output.py  |  8 ++--
 cookbook/models/cohere/structured_output.py   |  5 +--
 cookbook/models/deepseek/structured_output.py |  6 +--
 .../models/fireworks/structured_output.py     |  6 +--
 cookbook/models/groq/structured_output.py     |  6 +--
 cookbook/models/mistral/structured_output.py  | 11 ++---
 cookbook/models/ollama/structured_output.py   | 12 +++---
 cookbook/models/xai/structured_output.py      | 42 +++++++++++++++++++
 libs/agno/agno/models/azure/ai_foundry.py     |  2 +-
 libs/agno/agno/models/azure/openai_chat.py    |  2 +
 libs/agno/agno/models/cohere/chat.py          |  3 ++
 libs/agno/agno/models/mistral/mistral.py      | 40 ++++++++++++------
 libs/agno/agno/models/ollama/chat.py          |  7 +---
 13 files changed, 105 insertions(+), 45 deletions(-)
 create mode 100644 cookbook/models/xai/structured_output.py

diff --git a/cookbook/models/azure/openai/structured_output.py b/cookbook/models/azure/openai/structured_output.py
index b1274b9c72..aae42b6b1f 100644
--- a/cookbook/models/azure/openai/structured_output.py
+++ b/cookbook/models/azure/openai/structured_output.py
@@ -29,11 +29,11 @@ class MovieScript(BaseModel):
     model=AzureOpenAI(id="gpt-4o-mini"),
     description="You help people write movie scripts.",
     response_model=MovieScript,
-    # debug_mode=True,
+    structured_outputs=True,
 )
 
 # Get the response in a variable
-# run: RunResponse = agent.run("New York")
-# pprint(run.content)
+run: RunResponse = agent.run("New York")
+pprint(run.content)
 
-agent.print_response("New York")
+# agent.print_response("New York")
diff --git a/cookbook/models/cohere/structured_output.py b/cookbook/models/cohere/structured_output.py
index 33ab81dff7..2bc5c3714b 100644
--- a/cookbook/models/cohere/structured_output.py
+++ b/cookbook/models/cohere/structured_output.py
@@ -33,7 +33,6 @@ class MovieScript(BaseModel):
 )
 
 # Get the response in a variable
-# json_mode_response: RunResponse = json_mode_agent.run("New York")
-# pprint(json_mode_response.content)
+json_mode_response: RunResponse = json_mode_agent.run("New York")
+pprint(json_mode_response.content)
 
-json_mode_agent.print_response("New York")
diff --git a/cookbook/models/deepseek/structured_output.py b/cookbook/models/deepseek/structured_output.py
index 410f785624..b79041d6e7 100644
--- a/cookbook/models/deepseek/structured_output.py
+++ b/cookbook/models/deepseek/structured_output.py
@@ -32,7 +32,7 @@ class MovieScript(BaseModel):
 )
 
 # Get the response in a variable
-# json_mode_response: RunResponse = json_mode_agent.run("New York")
-# pprint(json_mode_response.content)
+json_mode_response: RunResponse = json_mode_agent.run("New York")
+pprint(json_mode_response.content)
 
-json_mode_agent.print_response("New York")
+# json_mode_agent.print_response("New York")
diff --git a/cookbook/models/fireworks/structured_output.py b/cookbook/models/fireworks/structured_output.py
index 80b9f70e9b..6a3db2dd80 100644
--- a/cookbook/models/fireworks/structured_output.py
+++ b/cookbook/models/fireworks/structured_output.py
@@ -33,7 +33,7 @@ class MovieScript(BaseModel):
 )
 
 # Get the response in a variable
-# response: RunResponse = agent.run("New York")
-# pprint(json_mode_response.content)
+response: RunResponse = agent.run("New York")
+pprint(response.content)
 
-agent.print_response("New York")
+# agent.print_response("New York")
diff --git a/cookbook/models/groq/structured_output.py b/cookbook/models/groq/structured_output.py
index 6b1f1bab08..eda0a70826 100644
--- a/cookbook/models/groq/structured_output.py
+++ b/cookbook/models/groq/structured_output.py
@@ -33,7 +33,7 @@ class MovieScript(BaseModel):
 )
 
 # Get the response in a variable
-# run: RunResponse = json_mode_agent.run("New York")
-# pprint(run.content)
+run: RunResponse = json_mode_agent.run("New York")
+pprint(run.content)
 
-json_mode_agent.print_response("New York")
+# json_mode_agent.print_response("New York")
diff --git a/cookbook/models/mistral/structured_output.py b/cookbook/models/mistral/structured_output.py
index d1dcf3610f..1ddcd3deed 100644
--- a/cookbook/models/mistral/structured_output.py
+++ b/cookbook/models/mistral/structured_output.py
@@ -27,19 +27,20 @@ class MovieScript(BaseModel):
 )
 
 
-json_mode_agent = Agent(
+structured_output_agent = Agent(
     model=MistralChat(
         id="mistral-large-latest",
     ),
     tools=[DuckDuckGoTools()],
     description="You help people write movie scripts.",
     response_model=MovieScript,
+    structured_outputs=True,
     show_tool_calls=True,
-    debug_mode=True,
+    # debug_mode=True,
 )
 
 # Get the response in a variable
-# json_mode_response: RunResponse = json_mode_agent.run("New York")
-# pprint(json_mode_response.content)
+structured_output_response: RunResponse = structured_output_agent.run("New York")
+pprint(structured_output_response.content)
 
-json_mode_agent.print_response("Find a cool movie idea about London and write it.")
+# json_mode_agent.print_response("Find a cool movie idea about London and write it.")
diff --git a/cookbook/models/ollama/structured_output.py b/cookbook/models/ollama/structured_output.py
index 8c525cdc9f..61ef951427 100644
--- a/cookbook/models/ollama/structured_output.py
+++ b/cookbook/models/ollama/structured_output.py
@@ -4,6 +4,9 @@
 from agno.agent import Agent
 from agno.models.ollama import Ollama
 from pydantic import BaseModel, Field
+from rich.pretty import pprint  # noqa
+
+from agno.run.response import RunResponse
 
 
 class MovieScript(BaseModel):
@@ -34,12 +37,9 @@ class MovieScript(BaseModel):
 )
 
 # Run the agent synchronously
-structured_output_agent.print_response("Llamas ruling the world")
+structured_output_response: RunResponse = structured_output_agent.run("Llamas ruling the world")
+pprint(structured_output_response.content)
 
 
 # Run the agent asynchronously
-async def run_agents_async():
-    await structured_output_agent.aprint_response("Llamas ruling the world")
-
-
-asyncio.run(run_agents_async())
+asyncio.run(structured_output_agent.aprint_response("Llamas ruling the world"))
diff --git a/cookbook/models/xai/structured_output.py b/cookbook/models/xai/structured_output.py
new file mode 100644
index 0000000000..15c983f3a7
--- /dev/null
+++ b/cookbook/models/xai/structured_output.py
@@ -0,0 +1,42 @@
+import asyncio
+from typing import List
+
+from agno.agent import Agent
+from pydantic import BaseModel, Field
+from rich.pretty import pprint  # noqa
+
+from agno.models.xai.xai import xAI
+from agno.run.response import RunResponse
+
+
+class MovieScript(BaseModel):
+    name: str = Field(..., description="Give a name to this movie")
+    setting: str = Field(
+        ..., description="Provide a nice setting for a blockbuster movie."
+    )
+    ending: str = Field(
+        ...,
+        description="Ending of the movie. If not available, provide a happy ending.",
+    )
+    genre: str = Field(
+        ...,
+        description="Genre of the movie. If not available, select action, thriller or romantic comedy.",
+    )
+    characters: List[str] = Field(..., description="Name of characters for this movie.")
+    storyline: str = Field(
+        ..., description="3 sentence storyline for the movie. Make it exciting!"
+    )
+
+
+# Agent that returns a structured output
+structured_output_agent = Agent(
+    model=xAI(id="grok-2-latest"),
+    description="You write movie scripts.",
+    response_model=MovieScript,
+    structured_outputs=True,
+)
+
+# Run the agent synchronously
+structured_output_response: RunResponse = structured_output_agent.run("Llamas ruling the world")
+pprint(structured_output_response.content)
+
diff --git a/libs/agno/agno/models/azure/ai_foundry.py b/libs/agno/agno/models/azure/ai_foundry.py
index 6b7cfa88d0..3e7ff66b27 100644
--- a/libs/agno/agno/models/azure/ai_foundry.py
+++ b/libs/agno/agno/models/azure/ai_foundry.py
@@ -27,7 +27,7 @@
     from azure.core.credentials import AzureKeyCredential
     from azure.core.exceptions import HttpResponseError
 except ImportError:
-    logger.error("`azure-ai-inference` not installed. Please install it via `pip install azure-ai-inference aiohttp`.")
+    raise ImportError("`azure-ai-inference` not installed. Please install it via `pip install azure-ai-inference aiohttp`.")
 
 
 @dataclass
diff --git a/libs/agno/agno/models/azure/openai_chat.py b/libs/agno/agno/models/azure/openai_chat.py
index c0cc39c8b6..dd06f02833 100644
--- a/libs/agno/agno/models/azure/openai_chat.py
+++ b/libs/agno/agno/models/azure/openai_chat.py
@@ -38,6 +38,8 @@ class AzureOpenAI(OpenAILike):
     name: str = "AzureOpenAI"
     provider: str = "Azure"
 
+    supports_structured_outputs: bool = True
+
     api_key: Optional[str] = None
     api_version: Optional[str] = "2024-10-21"
     azure_endpoint: Optional[str] = None
diff --git a/libs/agno/agno/models/cohere/chat.py b/libs/agno/agno/models/cohere/chat.py
index 7baffaa2c9..adae70575b 100644
--- a/libs/agno/agno/models/cohere/chat.py
+++ b/libs/agno/agno/models/cohere/chat.py
@@ -100,6 +100,9 @@ def request_kwargs(self) -> Dict[str, Any]:
         if self.presence_penalty:
             _request_params["presence_penalty"] = self.presence_penalty
 
+        if self.response_format:
+            _request_params["response_format"] = self.response_format
+
         if self._tools is not None and len(self._tools) > 0:
             _request_params["tools"] = self._tools
         if self.tool_choice is not None:
diff --git a/libs/agno/agno/models/mistral/mistral.py b/libs/agno/agno/models/mistral/mistral.py
index 42c6f23527..9cbfb6aeb5 100644
--- a/libs/agno/agno/models/mistral/mistral.py
+++ b/libs/agno/agno/models/mistral/mistral.py
@@ -119,6 +119,8 @@ class MistralChat(Model):
     name: str = "MistralChat"
     provider: str = "Mistral"
 
+    supports_structured_outputs: bool = True
+
     # -*- Request parameters
     temperature: Optional[float] = None
     max_tokens: Optional[int] = None
@@ -243,11 +245,19 @@ def invoke(self, messages: List[Message]) -> ChatCompletionResponse:
         """
         mistral_messages = _format_messages(messages)
         try:
-            response = self.get_client().chat.complete(
-                model=self.id,
-                messages=mistral_messages,
-                **self.request_kwargs,
-            )
+            if self.response_format is not None and self.structured_outputs:
+                response = self.get_client().chat.parse(
+                    model=self.id,
+                    messages=mistral_messages,
+                    response_format=self.response_format,
+                    **self.request_kwargs,
+                )
+            else:
+                response = self.get_client().chat.complete(
+                    model=self.id,
+                    messages=mistral_messages,
+                    **self.request_kwargs,
+                )
             if response is None:
                 raise ValueError("Chat completion returned None")
             return response
@@ -298,13 +308,19 @@ async def ainvoke(self, messages: List[Message]) -> ChatCompletionResponse:
         """
         mistral_messages = _format_messages(messages)
         try:
-            response = await self.get_client().chat.complete_async(
-                model=self.id,
-                messages=mistral_messages,
-                **self.request_kwargs,
-            )
-            if response is None:
-                raise ValueError("Chat completion returned None")
+            if self.response_format is not None and self.structured_outputs:
+                response = await self.get_client().chat.parse_async(
+                    model=self.id,
+                    messages=mistral_messages,
+                    response_format=self.response_format,
+                    **self.request_kwargs,
+                )
+            else:
+                response = await self.get_client().chat.complete_async(
+                    model=self.id,
+                    messages=mistral_messages,
+                    **self.request_kwargs,
+                )
             return response
         except HTTPValidationError as e:
             logger.error(f"HTTPValidationError from Mistral: {e}")
diff --git a/libs/agno/agno/models/ollama/chat.py b/libs/agno/agno/models/ollama/chat.py
index 0e89649c78..b193e72b49 100644
--- a/libs/agno/agno/models/ollama/chat.py
+++ b/libs/agno/agno/models/ollama/chat.py
@@ -40,6 +40,7 @@ class Ollama(Model):
     id: str = "llama3.1"
     name: str = "Ollama"
    provider: str = "Ollama"
+    supports_structured_outputs: bool = True
 
 
     # Request parameters
@@ -57,10 +58,6 @@ class Ollama(Model):
     client: Optional[OllamaClient] = None
     async_client: Optional[AsyncOllamaClient] = None
 
-    # Internal parameters. Not used for API requests
-    # Whether to use the structured outputs with this Model.
-    structured_outputs: bool = False
-
     def _get_client_params(self) -> Dict[str, Any]:
         base_params = {
             "host": self.host,
@@ -281,7 +278,7 @@ def parse_provider_response(self, response: ChatResponse) -> ModelResponse:
                 and self.structured_outputs
                 and issubclass(self.response_format, BaseModel)
             ):
-                parsed_object = response_message.parsed  # type: ignore
+                parsed_object = response_message.content  # type: ignore
                 if parsed_object is not None:
                     model_response.parsed = parsed_object
         except Exception as e:

From 5c67509333c7bbe24289f6206e6d80fb2cbd2fce Mon Sep 17 00:00:00 2001
From: Dirk Brand
Date: Thu, 13 Feb 2025 16:38:47 +0200
Subject: [PATCH 2/6] Fix style

---
 cookbook/models/cohere/structured_output.py   |  1 -
 cookbook/models/ollama/structured_output.py   |  7 +-
 cookbook/models/xai/structured_output.py      | 10 +--
 libs/agno/agno/agent/agent.py                 |  6 +-
 libs/agno/agno/models/azure/ai_foundry.py     |  4 +-
 libs/agno/agno/models/google/gemini.py        |  2 -
 libs/agno/agno/models/groq/groq.py            |  2 +-
 libs/agno/agno/models/mistral/mistral.py      | 15 ++--
 libs/agno/agno/vectordb/chroma/chromadb.py    |  8 +-
 .../agno/tests/unit/vectordb/test_chromadb.py | 88 +++++++++----------
 10 files changed, 68 insertions(+), 75 deletions(-)

diff --git a/cookbook/models/cohere/structured_output.py b/cookbook/models/cohere/structured_output.py
index 2bc5c3714b..a6b5deea76 100644
--- a/cookbook/models/cohere/structured_output.py
+++ b/cookbook/models/cohere/structured_output.py
@@ -35,4 +35,3 @@ class MovieScript(BaseModel):
 # Get the response in a variable
 json_mode_response: RunResponse = json_mode_agent.run("New York")
 pprint(json_mode_response.content)
-
diff --git a/cookbook/models/ollama/structured_output.py b/cookbook/models/ollama/structured_output.py
index 61ef951427..19f937e580 100644
--- a/cookbook/models/ollama/structured_output.py
+++ b/cookbook/models/ollama/structured_output.py
@@ -3,11 +3,10 @@
 
 from agno.agent import Agent
 from agno.models.ollama import Ollama
+from agno.run.response import RunResponse
 from pydantic import BaseModel, Field
 from rich.pretty import pprint  # noqa
 
-from agno.run.response import RunResponse
-
 
 class MovieScript(BaseModel):
     name: str = Field(..., description="Give a name to this movie")
@@ -37,7 +36,9 @@ class MovieScript(BaseModel):
 )
 
 # Run the agent synchronously
-structured_output_response: RunResponse = structured_output_agent.run("Llamas ruling the world")
+structured_output_response: RunResponse = structured_output_agent.run(
+    "Llamas ruling the world"
+)
 pprint(structured_output_response.content)
 
 
diff --git a/cookbook/models/xai/structured_output.py b/cookbook/models/xai/structured_output.py
index 15c983f3a7..41d4c41ae5 100644
--- a/cookbook/models/xai/structured_output.py
+++ b/cookbook/models/xai/structured_output.py
@@ -2,11 +2,10 @@
 from typing import List
 
 from agno.agent import Agent
-from pydantic import BaseModel, Field
-from rich.pretty import pprint  # noqa
-
 from agno.models.xai.xai import xAI
 from agno.run.response import RunResponse
+from pydantic import BaseModel, Field
+from rich.pretty import pprint  # noqa
 
 
 class MovieScript(BaseModel):
@@ -37,6 +36,7 @@ class MovieScript(BaseModel):
 )
 
 # Run the agent synchronously
-structured_output_response: RunResponse = structured_output_agent.run("Llamas ruling the world")
+structured_output_response: RunResponse = structured_output_agent.run(
+    "Llamas ruling the world"
+)
 pprint(structured_output_response.content)
-
diff --git a/libs/agno/agno/agent/agent.py b/libs/agno/agno/agent/agent.py
index e71692887c..43a7dc6bc4 100644
--- a/libs/agno/agno/agent/agent.py
+++ b/libs/agno/agno/agent/agent.py
@@ -881,9 +881,11 @@ def run(
 
                     import time
                     time.sleep(delay)
-
+
             if last_exception is not None:
-                raise Exception(f"Failed after {num_attempts} attempts. Last error using {last_exception.model_name}({last_exception.model_id}): {str(last_exception)}")
+                raise Exception(
+                    f"Failed after {num_attempts} attempts. Last error using {last_exception.model_name}({last_exception.model_id}): {str(last_exception)}"
+                )
             else:
                 raise Exception(f"Failed after {num_attempts} attempts.")
 
diff --git a/libs/agno/agno/models/azure/ai_foundry.py b/libs/agno/agno/models/azure/ai_foundry.py
index 3e7ff66b27..b4b23952d4 100644
--- a/libs/agno/agno/models/azure/ai_foundry.py
+++ b/libs/agno/agno/models/azure/ai_foundry.py
@@ -27,7 +27,9 @@
     from azure.core.credentials import AzureKeyCredential
     from azure.core.exceptions import HttpResponseError
 except ImportError:
-    raise ImportError("`azure-ai-inference` not installed. Please install it via `pip install azure-ai-inference aiohttp`.")
+    raise ImportError(
+        "`azure-ai-inference` not installed. Please install it via `pip install azure-ai-inference aiohttp`."
+    )
 
 
 @dataclass
diff --git a/libs/agno/agno/models/google/gemini.py b/libs/agno/agno/models/google/gemini.py
index 9503f26e1b..16f5544997 100644
--- a/libs/agno/agno/models/google/gemini.py
+++ b/libs/agno/agno/models/google/gemini.py
@@ -638,8 +638,6 @@ def parse_provider_response(self, response: GenerateContentResponse) -> ModelResponse:
 
                     model_response.tool_calls.append(tool_call)
 
-
-
         # Extract usage metadata if present
         if hasattr(response, "usage_metadata"):
             usage: GenerateContentResponseUsageMetadata = response.usage_metadata
diff --git a/libs/agno/agno/models/groq/groq.py b/libs/agno/agno/models/groq/groq.py
index 1c7ab0677d..9d703a8dca 100644
--- a/libs/agno/agno/models/groq/groq.py
+++ b/libs/agno/agno/models/groq/groq.py
@@ -12,9 +12,9 @@
 from agno.utils.openai import add_images_to_message
 
 try:
+    from groq import APIConnectionError, APIError, APIStatusError, APITimeoutError
     from groq import AsyncGroq as AsyncGroqClient
     from groq import Groq as GroqClient
-    from groq import APIError, APIConnectionError, APITimeoutError, APIStatusError
     from groq.types.chat import ChatCompletion
     from groq.types.chat.chat_completion_chunk import ChatCompletionChunk, ChoiceDelta, ChoiceDeltaToolCall
 except (ModuleNotFoundError, ImportError):
diff --git a/libs/agno/agno/models/mistral/mistral.py b/libs/agno/agno/models/mistral/mistral.py
index 9cbfb6aeb5..66c73b9db2 100644
--- a/libs/agno/agno/models/mistral/mistral.py
+++ b/libs/agno/agno/models/mistral/mistral.py
@@ -12,6 +12,7 @@
 try:
     from mistralai import CompletionEvent
     from mistralai import Mistral as MistralClient
+    from mistralai.extra.struct_chat import ParsedChatCompletionResponse
     from mistralai.models import (
         AssistantMessage,
         HTTPValidationError,
@@ -233,7 +234,7 @@ def to_dict(self) -> Dict[str, Any]:
         cleaned_dict = {k: v for k, v in _dict.items() if v is not None}
         return cleaned_dict
 
-    def invoke(self, messages: List[Message]) -> ChatCompletionResponse:
+    def invoke(self, messages: List[Message]) -> Union[ChatCompletionResponse, ParsedChatCompletionResponse]:
         """
         Send a chat completion request to the Mistral model.
 
@@ -245,11 +246,12 @@ def invoke(self, messages: List[Message]) -> ChatCompletionResponse:
         """
         mistral_messages = _format_messages(messages)
         try:
+            response: Union[ChatCompletionResponse, ParsedChatCompletionResponse]
             if self.response_format is not None and self.structured_outputs:
                 response = self.get_client().chat.parse(
                     model=self.id,
                     messages=mistral_messages,
-                    response_format=self.response_format,
+                    response_format=self.response_format,  # type: ignore
                     **self.request_kwargs,
                 )
             else:
@@ -258,8 +260,6 @@ def invoke(self, messages: List[Message]) -> ChatCompletionResponse:
                     messages=mistral_messages,
                     **self.request_kwargs,
                 )
-            if response is None:
-                raise ValueError("Chat completion returned None")
 
             return response
         except HTTPValidationError as e:
@@ -286,8 +286,6 @@ def invoke_stream(self, messages: List[Message]) -> Iterator[Any]:
                 messages=mistral_messages,
                 **self.request_kwargs,
             )
-            if stream is None:
-                raise ValueError("Chat stream returned None")
             return stream
         except HTTPValidationError as e:
             logger.error(f"HTTPValidationError from Mistral: {e}")
@@ -296,7 +294,7 @@ def invoke_stream(self, messages: List[Message]) -> Iterator[Any]:
             logger.error(f"SDKError from Mistral: {e}")
             raise ModelProviderError(e, self.name, self.id) from e
 
-    async def ainvoke(self, messages: List[Message]) -> ChatCompletionResponse:
+    async def ainvoke(self, messages: List[Message]) -> Union[ChatCompletionResponse, ParsedChatCompletionResponse]:
         """
         Send an asynchronous chat completion request to the Mistral API.
 
@@ -308,11 +306,12 @@ async def ainvoke(self, messages: List[Message]) -> ChatCompletionResponse:
         """
         mistral_messages = _format_messages(messages)
         try:
+            response: Union[ChatCompletionResponse, ParsedChatCompletionResponse]
             if self.response_format is not None and self.structured_outputs:
                 response = await self.get_client().chat.parse_async(
                     model=self.id,
                     messages=mistral_messages,
-                    response_format=self.response_format,
+                    response_format=self.response_format,  # type: ignore
                     **self.request_kwargs,
                 )
             else:
diff --git a/libs/agno/agno/vectordb/chroma/chromadb.py b/libs/agno/agno/vectordb/chroma/chromadb.py
index 572db452d2..e044c346f6 100644
--- a/libs/agno/agno/vectordb/chroma/chromadb.py
+++ b/libs/agno/agno/vectordb/chroma/chromadb.py
@@ -207,7 +207,7 @@ def search(self, query: str, limit: int = 5, filters: Optional[Dict[str, Any]] =
         result: QueryResult = self._collection.query(
             query_embeddings=query_embedding,
             n_results=limit,
-            include=["metadatas", "documents", "embeddings", "distances", "uris"],
+            include=["metadatas", "documents", "embeddings", "distances", "uris"],  # type: ignore
         )
 
         # Build search results
@@ -216,12 +216,12 @@ def search(self, query: str, limit: int = 5, filters: Optional[Dict[str, Any]] =
         ids = result.get("ids", [[]])[0]
         metadata = result.get("metadatas", [{}])[0]  # type: ignore
         documents = result.get("documents", [[]])[0]  # type: ignore
-        embeddings = result.get("embeddings")[0]
-        embeddings = [e.tolist() if hasattr(e, "tolist") else e for e in embeddings]
+        embeddings = result.get("embeddings")[0]  # type: ignore
+        embeddings = [e.tolist() if hasattr(e, "tolist") else e for e in embeddings]  # type: ignore
         distances = result.get("distances", [[]])[0]  # type: ignore
 
         for idx, distance in enumerate(distances):
-            metadata[idx]["distances"] = distance
+            metadata[idx]["distances"] = distance  # type: ignore
 
         try:
             # Use zip to iterate over multiple lists simultaneously
diff --git a/libs/agno/tests/unit/vectordb/test_chromadb.py b/libs/agno/tests/unit/vectordb/test_chromadb.py
index c38bb3c323..b13690aa74 100644
--- a/libs/agno/tests/unit/vectordb/test_chromadb.py
+++ b/libs/agno/tests/unit/vectordb/test_chromadb.py
@@ -1,81 +1,80 @@
-import pytest
-from typing import List
 import os
 import shutil
+from typing import List
+
+import pytest
 
-from agno.vectordb.chroma import ChromaDb
 from agno.document import Document
 from agno.embedder.openai import OpenAIEmbedder
+from agno.vectordb.chroma import ChromaDb
 from agno.vectordb.distance import Distance
 
 TEST_COLLECTION = "test_collection"
 TEST_PATH = "tmp/test_chromadb"
 
+
 @pytest.fixture
 def chroma_db():
     """Fixture to create and clean up a ChromaDb instance"""
     # Ensure the test directory exists with proper permissions
     os.makedirs(TEST_PATH, exist_ok=True)
-
+
     # Clean up any existing data before the test
     if os.path.exists(TEST_PATH):
         shutil.rmtree(TEST_PATH)
         os.makedirs(TEST_PATH)
-
-    db = ChromaDb(
-        collection=TEST_COLLECTION,
-        path=TEST_PATH,
-        persistent_client=False
-    )
+
+    db = ChromaDb(collection=TEST_COLLECTION, path=TEST_PATH, persistent_client=False)
     db.create()
     yield db
-
+
     # Cleanup after test
     try:
         db.drop()
     except Exception:
         pass
-
+
     if os.path.exists(TEST_PATH):
         shutil.rmtree(TEST_PATH)
 
+
 @pytest.fixture
 def sample_documents() -> List[Document]:
     """Fixture to create sample documents"""
     return [
         Document(
-            content="Tom Kha Gai is a Thai coconut soup with chicken",
-            meta_data={"cuisine": "Thai", "type": "soup"}
-        ),
-        Document(
-            content="Pad Thai is a stir-fried rice noodle dish",
-            meta_data={"cuisine": "Thai", "type": "noodles"}
+            content="Tom Kha Gai is a Thai coconut soup with chicken", meta_data={"cuisine": "Thai", "type": "soup"}
         ),
+        Document(content="Pad Thai is a stir-fried rice noodle dish", meta_data={"cuisine": "Thai", "type": "noodles"}),
         Document(
             content="Green curry is a spicy Thai curry with coconut milk",
-            meta_data={"cuisine": "Thai", "type": "curry"}
-        )
+            meta_data={"cuisine": "Thai", "type": "curry"},
+        ),
     ]
 
+
 def test_create_collection(chroma_db):
     """Test creating a collection"""
     assert chroma_db.exists() is True
     assert chroma_db.get_count() == 0
 
+
 def test_insert_documents(chroma_db, sample_documents):
     """Test inserting documents"""
     chroma_db.insert(sample_documents)
     assert chroma_db.get_count() == 3
 
+
 def test_search_documents(chroma_db, sample_documents):
     """Test searching documents"""
     chroma_db.insert(sample_documents)
-
+
     # Search for coconut-related dishes
     results = chroma_db.search("coconut dishes", limit=2)
     assert len(results) == 2
     assert any("coconut" in doc.content.lower() for doc in results)
 
+
 def test_upsert_documents(chroma_db, sample_documents):
     """Test upserting documents"""
     # Initial insert
@@ -84,46 +83,39 @@ def test_upsert_documents(chroma_db, sample_documents):
 
     # Upsert same document with different content
     modified_doc = Document(
-        content="Tom Kha Gai is a spicy and sour Thai coconut soup",
-        meta_data={"cuisine": "Thai", "type": "soup"}
+        content="Tom Kha Gai is a spicy and sour Thai coconut soup", meta_data={"cuisine": "Thai", "type": "soup"}
     )
     chroma_db.upsert([modified_doc])
-
+
     # Search to verify the update
     results = chroma_db.search("spicy and sour", limit=1)
     assert len(results) == 1
     assert "spicy and sour" in results[0].content
 
+
 def test_delete_collection(chroma_db, sample_documents):
     """Test deleting collection"""
     chroma_db.insert(sample_documents)
     assert chroma_db.get_count() == 3
-
+
     assert chroma_db.delete() is True
     assert chroma_db.exists() is False
 
+
 def test_distance_metrics():
     """Test different distance metrics"""
     # Ensure the test directory exists
     os.makedirs(TEST_PATH, exist_ok=True)
-
-    db_cosine = ChromaDb(
-        collection="test_cosine",
-        path=TEST_PATH,
-        distance=Distance.cosine
-    )
+
+    db_cosine = ChromaDb(collection="test_cosine", path=TEST_PATH, distance=Distance.cosine)
     db_cosine.create()
-
-    db_euclidean = ChromaDb(
-        collection="test_euclidean",
-        path=TEST_PATH,
-        distance=Distance.l2
-    )
+
+    db_euclidean = ChromaDb(collection="test_euclidean", path=TEST_PATH, distance=Distance.l2)
     db_euclidean.create()
-
+
     assert db_cosine._collection is not None
     assert db_euclidean._collection is not None
-
+
     # Cleanup
     try:
         db_cosine.drop()
@@ -132,45 +124,45 @@ def test_distance_metrics():
     if os.path.exists(TEST_PATH):
         shutil.rmtree(TEST_PATH)
 
+
 def test_doc_exists(chroma_db, sample_documents):
     """Test document existence check"""
     chroma_db.insert([sample_documents[0]])
     assert chroma_db.doc_exists(sample_documents[0]) is True
 
+
 def test_get_count(chroma_db, sample_documents):
     """Test document count"""
     assert chroma_db.get_count() == 0
     chroma_db.insert(sample_documents)
     assert chroma_db.get_count() == 3
 
+
 @pytest.mark.asyncio
 async def test_error_handling(chroma_db):
     """Test error handling scenarios"""
     # Test search with invalid query
     results = chroma_db.search("")
     assert len(results) == 0
-
+
     # Test inserting empty document list
     chroma_db.insert([])
     assert chroma_db.get_count() == 0
 
+
 def test_custom_embedder():
     """Test using a custom embedder"""
     # Ensure the test directory exists
     os.makedirs(TEST_PATH, exist_ok=True)
-
+
     custom_embedder = OpenAIEmbedder()
-    db = ChromaDb(
-        collection=TEST_COLLECTION,
-        path=TEST_PATH,
-        embedder=custom_embedder
-    )
+    db = ChromaDb(collection=TEST_COLLECTION, path=TEST_PATH, embedder=custom_embedder)
     db.create()
     assert db.embedder == custom_embedder
-
+
     # Cleanup
     try:
         db.drop()
     finally:
         if os.path.exists(TEST_PATH):
-            shutil.rmtree(TEST_PATH)
\ No newline at end of file
+            shutil.rmtree(TEST_PATH)

From 704e033836d2fe3b1438e2d613387059a0558037 Mon Sep 17 00:00:00 2001
From: Dirk Brand
Date: Thu, 13 Feb 2025 17:23:56 +0200
Subject: [PATCH 3/6] update

---
 libs/agno/agno/agent/agent.py      |  1 +
 libs/agno/agno/models/groq/groq.py | 44 ++++++++++++++++--------------
 2 files changed, 25 insertions(+), 20 deletions(-)

diff --git a/libs/agno/agno/agent/agent.py b/libs/agno/agno/agent/agent.py
index 43a7dc6bc4..722b316e41 100644
--- a/libs/agno/agno/agent/agent.py
+++ b/libs/agno/agno/agent/agent.py
@@ -2457,6 +2457,7 @@ def get_relevant_docs_from_knowledge(
         if self.knowledge is None:
             return None
 
+        # TODO: add async support
         relevant_docs: List[Document] = self.knowledge.search(query=query, num_documents=num_documents, **kwargs)
         if len(relevant_docs) == 0:
             return None
diff --git a/libs/agno/agno/models/groq/groq.py b/libs/agno/agno/models/groq/groq.py
index 9d703a8dca..6c6de4bee5 100644
--- a/libs/agno/agno/models/groq/groq.py
+++ b/libs/agno/agno/models/groq/groq.py
@@ -21,22 +21,6 @@
     raise ImportError("`groq` not installed. Please install using `pip install groq`")
 
 
-def format_message(message: Message) -> Dict[str, Any]:
-    """
-    Format a message into the format expected by Groq.
-
-    Args:
-        message (Message): The message to format.
-
-    Returns:
-        Dict[str, Any]: The formatted message.
-    """
-    if message.role == "user":
-        if message.images is not None:
-            message = add_images_to_message(message=message, images=message.images)
-
-    return message.serialize_for_model()
-
 
 @dataclass
 class Groq(Model):
@@ -213,6 +197,26 @@ def to_dict(self) -> Dict[str, Any]:
         cleaned_dict = {k: v for k, v in model_dict.items() if v is not None}
         return cleaned_dict
 
+    def format_message(self, message: Message) -> Dict[str, Any]:
+        """
+        Format a message into the format expected by Groq.
+
+        Args:
+            message (Message): The message to format.
+
+        Returns:
+            Dict[str, Any]: The formatted message.
+        """
+        if message.role == "system" and self.response_format is not None and self.response_format.get("type") == "json_object":
+            # This is required by Groq to ensure the model outputs in the correct format
+            message.content += "\n\nYour output should be in JSON format."
+
+        if message.role == "user":
+            if message.images is not None:
+                message = add_images_to_message(message=message, images=message.images)
+
+        return message.serialize_for_model()
+
     def invoke(self, messages: List[Message]) -> ChatCompletion:
         """
         Send a chat completion request to the Groq API.
@@ -226,7 +230,7 @@ def invoke(self, messages: List[Message]) -> ChatCompletion:
         try:
             return self.get_client().chat.completions.create(
                 model=self.id,
-                messages=[format_message(m) for m in messages],  # type: ignore
+                messages=[self.format_message(m) for m in messages],  # type: ignore
                 **self.request_kwargs,
             )
         except (APIError, APIConnectionError, APITimeoutError, APIStatusError) as e:
@@ -249,7 +253,7 @@ async def ainvoke(self, messages: List[Message]) -> ChatCompletion:
         try:
             return await self.get_async_client().chat.completions.create(
                 model=self.id,
-                messages=[format_message(m) for m in messages],  # type: ignore
+                messages=[self.format_message(m) for m in messages],  # type: ignore
                 **self.request_kwargs,
             )
         except (APIError, APIConnectionError, APITimeoutError, APIStatusError) as e:
@@ -272,7 +276,7 @@ def invoke_stream(self, messages: List[Message]) -> Iterator[ChatCompletionChunk]:
         try:
             return self.get_client().chat.completions.create(
                 model=self.id,
-                messages=[format_message(m) for m in messages],  # type: ignore
+                messages=[self.format_message(m) for m in messages],  # type: ignore
                 stream=True,
                 **self.request_kwargs,
             )
@@ -297,7 +301,7 @@ async def ainvoke_stream(self, messages: List[Message]) -> Any:
         try:
             stream = await self.get_async_client().chat.completions.create(
                 model=self.id,
-                messages=[format_message(m) for m in messages],  # type: ignore
+                messages=[self.format_message(m) for m in messages],  # type: ignore
                 stream=True,
                 **self.request_kwargs,
             )

From 05f8e952db176cd60f3073265f68e7ffbe58d81a Mon Sep 17 00:00:00 2001
From: Dirk Brand
Date: Thu, 13 Feb 2025 17:28:22 +0200
Subject: [PATCH 4/6] Update

---
 libs/agno/agno/models/groq/groq.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/libs/agno/agno/models/groq/groq.py b/libs/agno/agno/models/groq/groq.py
index 6c6de4bee5..9243840584 100644
--- a/libs/agno/agno/models/groq/groq.py
+++ b/libs/agno/agno/models/groq/groq.py
@@ -21,7 +21,6 @@
     raise ImportError("`groq` not installed. Please install using `pip install groq`")
 
 
-
 @dataclass
 class Groq(Model):
     """
@@ -207,7 +206,12 @@ def format_message(self, message: Message) -> Dict[str, Any]:
         Returns:
             Dict[str, Any]: The formatted message.
         """
-        if message.role == "system" and self.response_format is not None and self.response_format.get("type") == "json_object":
+        if (
+            message.role == "system"
+            and isinstance(message.content, str)
+            and self.response_format is not None
+            and self.response_format.get("type") == "json_object"
+        ):
             # This is required by Groq to ensure the model outputs in the correct format
             message.content += "\n\nYour output should be in JSON format."
 

From 9355525d0992b1ffb906c337b5ce453a0d855f21 Mon Sep 17 00:00:00 2001
From: Dirk Brand
Date: Thu, 13 Feb 2025 17:43:02 +0200
Subject: [PATCH 5/6] Remove tool choice from Ollama

---
 libs/agno/agno/models/ollama/chat.py | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/libs/agno/agno/models/ollama/chat.py b/libs/agno/agno/models/ollama/chat.py
index b193e72b49..317b1198b6 100644
--- a/libs/agno/agno/models/ollama/chat.py
+++ b/libs/agno/agno/models/ollama/chat.py
@@ -115,8 +115,6 @@ def request_kwargs(self) -> Dict[str, Any]:
         # Add tools
         if self._tools is not None and len(self._tools) > 0:
             request_params["tools"] = self._tools
-        if self.tool_choice is not None:
-            request_params["tool_choice"] = self.tool_choice
         # Add additional request params if provided
         if self.request_params:
             request_params.update(self.request_params)
@@ -140,10 +138,6 @@ def to_dict(self) -> Dict[str, Any]:
         )
         if self._tools is not None:
             model_dict["tools"] = self._tools
-        if self.tool_choice is not None:
-            model_dict["tool_choice"] = self.tool_choice
-        else:
-            model_dict["tool_choice"] = "auto"
         cleaned_dict = {k: v for k, v in model_dict.items() if v is not None}
         return cleaned_dict
 

From 770e19578a4d5669c7c51c9004d176dcbc499efd Mon Sep 17 00:00:00 2001
From: Dirk Brand
Date: Wed, 19 Feb 2025 16:43:17 +0200
Subject: [PATCH 6/6] Update

---
 cookbook/models/ollama/structured_output.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/cookbook/models/ollama/structured_output.py b/cookbook/models/ollama/structured_output.py
index 19f937e580..570d8c6341 100644
--- a/cookbook/models/ollama/structured_output.py
+++ b/cookbook/models/ollama/structured_output.py
@@ -37,10 +37,10 @@ class MovieScript(BaseModel):
 
 # Run the agent synchronously
 structured_output_response: RunResponse = structured_output_agent.run(
-    "Llamas ruling the world"
+    "New York"
 )
 pprint(structured_output_response.content)
 
 
 # Run the agent asynchronously
-asyncio.run(structured_output_agent.aprint_response("Llamas ruling the world"))
+asyncio.run(structured_output_agent.aprint_response("New York"))