Skip to content

Commit

Permalink
Merge
Browse files Browse the repository at this point in the history
  • Loading branch information
dirkbrnd committed Feb 13, 2025
2 parents 7fd3369 + 135a5f7 commit 53d1a77
Show file tree
Hide file tree
Showing 4 changed files with 190 additions and 19 deletions.
Empty file modified cookbook/scripts/run_cassandra.sh
100644 → 100755
Empty file.
29 changes: 15 additions & 14 deletions libs/agno/agno/embedder/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,22 @@ class MistralEmbedder(Embedder):

@property
def client(self) -> Mistral:
if self.mistral_client:
return self.mistral_client
if not self.mistral_client:
_client_params: Dict[str, Any] = {}
if self.api_key:
_client_params["api_key"] = self.api_key
if self.endpoint:
_client_params["endpoint"] = self.endpoint
if self.max_retries is not None:
_client_params["max_retries"] = self.max_retries
if self.timeout is not None:
_client_params["timeout"] = self.timeout
if self.client_params:
_client_params.update(self.client_params)

_client_params: Dict[str, Any] = {}
if self.api_key:
_client_params["api_key"] = self.api_key
if self.endpoint:
_client_params["endpoint"] = self.endpoint
if self.max_retries:
_client_params["max_retries"] = self.max_retries
if self.timeout:
_client_params["timeout"] = self.timeout
if self.client_params:
_client_params.update(self.client_params)
return Mistral(**_client_params)
self.mistral_client = Mistral(**_client_params)

return self.mistral_client

def _response(self, text: str) -> EmbeddingResponse:
_request_params: Dict[str, Any] = {
Expand Down
10 changes: 5 additions & 5 deletions libs/agno/agno/vectordb/cassandra/cassandra.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,19 +62,19 @@ def doc_exists(self, document: Document) -> bool:
"""Check if a document exists by ID."""
query = f"SELECT COUNT(*) FROM {self.keyspace}.{self.table_name} WHERE row_id = %s"
result = self.session.execute(query, (document.id,))
return result[0].count > 0
return result.one()[0] > 0

def name_exists(self, name: str) -> bool:
"""Check if a document exists by name."""
query = f"SELECT COUNT(*) FROM {self.keyspace}.{self.table_name} WHERE document_name = %s"
query = f"SELECT COUNT(*) FROM {self.keyspace}.{self.table_name} WHERE document_name = %s ALLOW FILTERING"
result = self.session.execute(query, (name,))
return result[0].count > 0
return result.one()[0] > 0

def id_exists(self, id: str) -> bool:
"""Check if a document exists by ID."""
query = f"SELECT COUNT(*) FROM {self.keyspace}.{self.table_name} WHERE row_id = %s"
query = f"SELECT COUNT(*) FROM {self.keyspace}.{self.table_name} WHERE row_id = %s ALLOW FILTERING"
result = self.session.execute(query, (id,))
return result[0].count > 0
return result.one()[0] > 0

def insert(self, documents: List[Document], filters: Optional[Dict[str, Any]] = None) -> None:
logger.debug(f"Cassandra VectorDB : Inserting Documents to the table {self.table_name}")
Expand Down
170 changes: 170 additions & 0 deletions libs/agno/tests/unit/vectordb/test_cassandra.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
import time
import uuid
from typing import Generator

import pytest
from cassandra.cluster import Cluster, Session

from agno.document import Document
from agno.embedder.mistral import MistralEmbedder


@pytest.fixture(scope="session")
def cassandra_session() -> Generator[Session, None, None]:
"""Create a session-scoped connection to Cassandra."""
# Wait for Cassandra to be ready
max_retries = 5
retry_delay = 2

for attempt in range(max_retries):
try:
cluster = Cluster(["localhost"], port=9042)
session = cluster.connect()
print(f"Successfully connected to Cassandra on attempt {attempt + 1}")
break
except Exception:
if attempt == max_retries - 1:
raise
time.sleep(retry_delay)

# Create test keyspace
keyspace = "test_vectordb"
session.execute(f"""
CREATE KEYSPACE IF NOT EXISTS {keyspace}
WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': '1'}}
""")
session.set_keyspace(keyspace)

yield session

# Cleanup after all tests
session.execute(f"DROP KEYSPACE IF EXISTS {keyspace}")
cluster.shutdown()


@pytest.fixture
def vector_db(cassandra_session):
"""Create a fresh VectorDB instance for each test."""
from agno.vectordb.cassandra import Cassandra

table_name = f"test_vectors_{uuid.uuid4().hex[:8]}"
db = Cassandra(
table_name=table_name, keyspace="test_vectordb", embedder=MistralEmbedder(), session=cassandra_session
)
db.create()

assert db.exists(), "Table was not created successfully"

yield db

# Cleanup after each test
db.drop()


def create_test_documents(num_docs: int = 3) -> list[Document]:
"""Helper function to create test documents."""
return [
Document(
id=f"doc_{i}",
content=f"This is test document {i}",
meta_data={"type": "test", "index": str(i)},
name=f"test_doc_{i}",
)
for i in range(num_docs)
]


def test_initialization(cassandra_session):
"""Test VectorDB initialization."""
from agno.vectordb.cassandra import Cassandra

# Test successful initialization
db = Cassandra(table_name="test_vectors", keyspace="test_vectordb", session=cassandra_session)
assert db.table_name == "test_vectors"
assert db.keyspace == "test_vectordb"

# Test initialization failures
with pytest.raises(ValueError):
Cassandra(table_name="", keyspace="test_vectordb", session=cassandra_session)

with pytest.raises(ValueError):
Cassandra(table_name="test_vectors", keyspace="", session=cassandra_session)

with pytest.raises(ValueError):
Cassandra(table_name="test_vectors", keyspace="test_vectordb", session=None)


def test_insert_and_search(vector_db):
"""Test document insertion and search functionality."""
# Insert test documents
docs = create_test_documents(1)
vector_db.insert(docs)

time.sleep(1)

# Test search functionality
results = vector_db.search("test document", limit=1)
assert len(results) == 1
assert all(isinstance(doc, Document) for doc in results)

# Test vector search
results = vector_db.vector_search("test document 1", limit=2)


def test_document_existence(vector_db):
"""Test document existence checking methods."""
docs = create_test_documents(1)
vector_db.insert(docs)

# Test by document object
assert vector_db.doc_exists(docs[0]) is True

# Test by name
assert vector_db.name_exists("test_doc_0") is True
assert vector_db.name_exists("nonexistent") is False

# Test by ID
assert vector_db.id_exists("doc_0") is True
assert vector_db.id_exists("nonexistent") is False


def test_upsert(vector_db):
"""Test upsert functionality."""
# Initial insert
docs = create_test_documents(1)
vector_db.insert(docs)

# Modify document and upsert
modified_doc = Document(
id=docs[0].id, content="Modified content", meta_data={"type": "modified"}, name=docs[0].name
)
vector_db.upsert([modified_doc])

# Verify modification
results = vector_db.search("Modified content", limit=1)
assert len(results) == 1
assert results[0].content == "Modified content"
assert results[0].meta_data["type"] == "modified"


def test_delete_and_drop(vector_db):
"""Test delete and drop functionality."""
# Insert documents
docs = create_test_documents()
vector_db.insert(docs)

# Test delete
assert vector_db.delete() is True
results = vector_db.search("test document", limit=5)
assert len(results) == 0

# Test drop
vector_db.drop()
assert vector_db.exists() is False


def test_exists(vector_db):
"""Test table existence checking."""
assert vector_db.exists() is True
vector_db.drop()
assert vector_db.exists() is False

0 comments on commit 53d1a77

Please sign in to comment.