Skip to content

Commit

Permalink
fix: mistral initialization (#2106)
Browse files Browse the repository at this point in the history
## Description

- **Summary of changes**: Issue initializing mistral that threw the
error :
```
Expected sequence of size 1024 for vector of type float and dimension 1024, observed sequence of length 0
ask
```
on running:
`cookbook/agent_concepts/knowledge/vector_dbs/cassandra_db.py`
  • Loading branch information
pritipsingh authored Feb 13, 2025
1 parent c4ca80d commit 135a5f7
Show file tree
Hide file tree
Showing 8 changed files with 235 additions and 72 deletions.
Empty file modified cookbook/scripts/run_cassandra.sh
100644 → 100755
Empty file.
6 changes: 4 additions & 2 deletions libs/agno/agno/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -881,9 +881,11 @@ def run(
import time

time.sleep(delay)

if last_exception is not None:
raise Exception(f"Failed after {num_attempts} attempts. Last error using {last_exception.model_name}({last_exception.model_id}): {str(last_exception)}")
raise Exception(
f"Failed after {num_attempts} attempts. Last error using {last_exception.model_name}({last_exception.model_id}): {str(last_exception)}"
)
else:
raise Exception(f"Failed after {num_attempts} attempts.")

Expand Down
29 changes: 15 additions & 14 deletions libs/agno/agno/embedder/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,22 @@ class MistralEmbedder(Embedder):

@property
def client(self) -> Mistral:
if self.mistral_client:
return self.mistral_client
if not self.mistral_client:
_client_params: Dict[str, Any] = {}
if self.api_key:
_client_params["api_key"] = self.api_key
if self.endpoint:
_client_params["endpoint"] = self.endpoint
if self.max_retries is not None:
_client_params["max_retries"] = self.max_retries
if self.timeout is not None:
_client_params["timeout"] = self.timeout
if self.client_params:
_client_params.update(self.client_params)

_client_params: Dict[str, Any] = {}
if self.api_key:
_client_params["api_key"] = self.api_key
if self.endpoint:
_client_params["endpoint"] = self.endpoint
if self.max_retries:
_client_params["max_retries"] = self.max_retries
if self.timeout:
_client_params["timeout"] = self.timeout
if self.client_params:
_client_params.update(self.client_params)
return Mistral(**_client_params)
self.mistral_client = Mistral(**_client_params)

return self.mistral_client

def _response(self, text: str) -> EmbeddingResponse:
_request_params: Dict[str, Any] = {
Expand Down
2 changes: 0 additions & 2 deletions libs/agno/agno/models/google/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,8 +638,6 @@ def parse_provider_response(self, response: GenerateContentResponse) -> ModelRes

model_response.tool_calls.append(tool_call)



# Extract usage metadata if present
if hasattr(response, "usage_metadata"):
usage: GenerateContentResponseUsageMetadata = response.usage_metadata
Expand Down
2 changes: 1 addition & 1 deletion libs/agno/agno/models/groq/groq.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
from agno.utils.openai import add_images_to_message

try:
from groq import APIConnectionError, APIError, APIStatusError, APITimeoutError
from groq import AsyncGroq as AsyncGroqClient
from groq import Groq as GroqClient
from groq import APIError, APIConnectionError, APITimeoutError, APIStatusError
from groq.types.chat import ChatCompletion
from groq.types.chat.chat_completion_chunk import ChatCompletionChunk, ChoiceDelta, ChoiceDeltaToolCall
except (ModuleNotFoundError, ImportError):
Expand Down
10 changes: 5 additions & 5 deletions libs/agno/agno/vectordb/cassandra/cassandra.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,19 +62,19 @@ def doc_exists(self, document: Document) -> bool:
"""Check if a document exists by ID."""
query = f"SELECT COUNT(*) FROM {self.keyspace}.{self.table_name} WHERE row_id = %s"
result = self.session.execute(query, (document.id,))
return result[0].count > 0
return result.one()[0] > 0

def name_exists(self, name: str) -> bool:
"""Check if a document exists by name."""
query = f"SELECT COUNT(*) FROM {self.keyspace}.{self.table_name} WHERE document_name = %s"
query = f"SELECT COUNT(*) FROM {self.keyspace}.{self.table_name} WHERE document_name = %s ALLOW FILTERING"
result = self.session.execute(query, (name,))
return result[0].count > 0
return result.one()[0] > 0

def id_exists(self, id: str) -> bool:
"""Check if a document exists by ID."""
query = f"SELECT COUNT(*) FROM {self.keyspace}.{self.table_name} WHERE row_id = %s"
query = f"SELECT COUNT(*) FROM {self.keyspace}.{self.table_name} WHERE row_id = %s ALLOW FILTERING"
result = self.session.execute(query, (id,))
return result[0].count > 0
return result.one()[0] > 0

def insert(self, documents: List[Document], filters: Optional[Dict[str, Any]] = None) -> None:
logger.debug(f"Cassandra VectorDB : Inserting Documents to the table {self.table_name}")
Expand Down
170 changes: 170 additions & 0 deletions libs/agno/tests/unit/vectordb/test_cassandra.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
import time
import uuid
from typing import Generator

import pytest
from cassandra.cluster import Cluster, Session

from agno.document import Document
from agno.embedder.mistral import MistralEmbedder


@pytest.fixture(scope="session")
def cassandra_session() -> Generator[Session, None, None]:
"""Create a session-scoped connection to Cassandra."""
# Wait for Cassandra to be ready
max_retries = 5
retry_delay = 2

for attempt in range(max_retries):
try:
cluster = Cluster(["localhost"], port=9042)
session = cluster.connect()
print(f"Successfully connected to Cassandra on attempt {attempt + 1}")
break
except Exception:
if attempt == max_retries - 1:
raise
time.sleep(retry_delay)

# Create test keyspace
keyspace = "test_vectordb"
session.execute(f"""
CREATE KEYSPACE IF NOT EXISTS {keyspace}
WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': '1'}}
""")
session.set_keyspace(keyspace)

yield session

# Cleanup after all tests
session.execute(f"DROP KEYSPACE IF EXISTS {keyspace}")
cluster.shutdown()


@pytest.fixture
def vector_db(cassandra_session):
"""Create a fresh VectorDB instance for each test."""
from agno.vectordb.cassandra import Cassandra

table_name = f"test_vectors_{uuid.uuid4().hex[:8]}"
db = Cassandra(
table_name=table_name, keyspace="test_vectordb", embedder=MistralEmbedder(), session=cassandra_session
)
db.create()

assert db.exists(), "Table was not created successfully"

yield db

# Cleanup after each test
db.drop()


def create_test_documents(num_docs: int = 3) -> list[Document]:
"""Helper function to create test documents."""
return [
Document(
id=f"doc_{i}",
content=f"This is test document {i}",
meta_data={"type": "test", "index": str(i)},
name=f"test_doc_{i}",
)
for i in range(num_docs)
]


def test_initialization(cassandra_session):
"""Test VectorDB initialization."""
from agno.vectordb.cassandra import Cassandra

# Test successful initialization
db = Cassandra(table_name="test_vectors", keyspace="test_vectordb", session=cassandra_session)
assert db.table_name == "test_vectors"
assert db.keyspace == "test_vectordb"

# Test initialization failures
with pytest.raises(ValueError):
Cassandra(table_name="", keyspace="test_vectordb", session=cassandra_session)

with pytest.raises(ValueError):
Cassandra(table_name="test_vectors", keyspace="", session=cassandra_session)

with pytest.raises(ValueError):
Cassandra(table_name="test_vectors", keyspace="test_vectordb", session=None)


def test_insert_and_search(vector_db):
"""Test document insertion and search functionality."""
# Insert test documents
docs = create_test_documents(1)
vector_db.insert(docs)

time.sleep(1)

# Test search functionality
results = vector_db.search("test document", limit=1)
assert len(results) == 1
assert all(isinstance(doc, Document) for doc in results)

# Test vector search
results = vector_db.vector_search("test document 1", limit=2)


def test_document_existence(vector_db):
"""Test document existence checking methods."""
docs = create_test_documents(1)
vector_db.insert(docs)

# Test by document object
assert vector_db.doc_exists(docs[0]) is True

# Test by name
assert vector_db.name_exists("test_doc_0") is True
assert vector_db.name_exists("nonexistent") is False

# Test by ID
assert vector_db.id_exists("doc_0") is True
assert vector_db.id_exists("nonexistent") is False


def test_upsert(vector_db):
"""Test upsert functionality."""
# Initial insert
docs = create_test_documents(1)
vector_db.insert(docs)

# Modify document and upsert
modified_doc = Document(
id=docs[0].id, content="Modified content", meta_data={"type": "modified"}, name=docs[0].name
)
vector_db.upsert([modified_doc])

# Verify modification
results = vector_db.search("Modified content", limit=1)
assert len(results) == 1
assert results[0].content == "Modified content"
assert results[0].meta_data["type"] == "modified"


def test_delete_and_drop(vector_db):
"""Test delete and drop functionality."""
# Insert documents
docs = create_test_documents()
vector_db.insert(docs)

# Test delete
assert vector_db.delete() is True
results = vector_db.search("test document", limit=5)
assert len(results) == 0

# Test drop
vector_db.drop()
assert vector_db.exists() is False


def test_exists(vector_db):
"""Test table existence checking."""
assert vector_db.exists() is True
vector_db.drop()
assert vector_db.exists() is False
Loading

0 comments on commit 135a5f7

Please sign in to comment.