Merge pull request #397 from aurelio-labs/vittorio/add-async-get-routes-method-to-pinecone-index

feat: Implemented aget_routes async method for pinecone index
jamescalam authored Aug 23, 2024
2 parents db451eb + 7f2cfae commit 567e4c9
Showing 9 changed files with 137 additions and 9 deletions.
2 changes: 1 addition & 1 deletion docs/source/conf.py
@@ -15,7 +15,7 @@
project = "Semantic Router"
copyright = "2024, Aurelio AI"
author = "Aurelio AI"
release = "0.0.60"
release = "0.0.61"

# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "semantic-router"
version = "0.0.60"
version = "0.0.61"
description = "Super fast semantic router for AI decision making"
authors = [
"James Briggs <james@aurelio.ai>",
2 changes: 1 addition & 1 deletion semantic_router/__init__.py
@@ -4,4 +4,4 @@

__all__ = ["RouteLayer", "HybridRouteLayer", "Route", "LayerConfig"]

__version__ = "0.0.60"
__version__ = "0.0.61"
11 changes: 11 additions & 0 deletions semantic_router/index/base.py
@@ -90,6 +90,17 @@ async def aquery(
"""
raise NotImplementedError("This method should be implemented by subclasses.")

def aget_routes(self):
"""
Asynchronously get a list of route and utterance objects currently stored in the index.
This method should be implemented by subclasses.
:returns: A list of tuples, each containing a route name and an associated utterance.
:rtype: list[tuple]
:raises NotImplementedError: If the method is not implemented by the subclass.
"""
raise NotImplementedError("This method should be implemented by subclasses.")

def delete_index(self):
"""
Deletes or resets the index.
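For orientation only (not part of this commit), below is a minimal sketch of how an index subclass might satisfy the new aget_routes hook added to BaseIndex above. The InMemoryIndex class and its records field are hypothetical, and the sketch assumes BaseIndex can be subclassed with an extra field in the usual pydantic style.

# Hypothetical sketch, not code from this repository: a subclass
# overriding the new aget_routes hook on BaseIndex.
from typing import List, Tuple

from semantic_router.index.base import BaseIndex


class InMemoryIndex(BaseIndex):
    # Assumed in-memory store; each record is taken to carry the same
    # metadata keys that PineconeIndex uses for route name and utterance.
    records: List[dict] = []

    async def aget_routes(self) -> List[Tuple]:
        # Return (route_name, utterance) tuples, matching the contract
        # documented on BaseIndex.aget_routes.
        return [(r["sr_route"], r["sr_utterance"]) for r in self.records]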
3 changes: 3 additions & 0 deletions semantic_router/index/local.py
@@ -128,6 +128,9 @@ async def aquery(
route_names = [self.routes[i] for i in idx]
return scores, route_names

def aget_routes(self):
logger.error("Sync remove is not implemented for LocalIndex.")

def delete(self, route_name: str):
"""
Delete all records of a specific route from the index.
108 changes: 108 additions & 0 deletions semantic_router/index/pinecone.py
@@ -528,6 +528,18 @@ async def aquery(
route_names = [result["metadata"]["sr_route"] for result in results["matches"]]
return np.array(scores), route_names

async def aget_routes(self) -> list[tuple]:
"""
Asynchronously get a list of route and utterance objects currently stored in the index.
Returns:
List[Tuple]: A list of (route_name, utterance) tuples.
"""
if self.async_client is None or self.host is None:
raise ValueError("Async client or host is not initialized.")

return await self._async_get_routes()

def delete_index(self):
self.client.delete_index(self.index_name)

@@ -584,5 +596,101 @@ async def _async_describe_index(self, name: str):
async with self.async_client.get(f"{self.base_url}/indexes/{name}") as response:
return await response.json(content_type=None)

async def _async_get_all(
self, prefix: Optional[str] = None, include_metadata: bool = False
) -> tuple[list[str], list[dict]]:
"""
Retrieves all vector IDs from the Pinecone index using pagination asynchronously.
"""
if self.index is None:
raise ValueError("Index is None, could not retrieve vector IDs.")

all_vector_ids = []
next_page_token = None

if prefix:
prefix_str = f"?prefix={prefix}"
else:
prefix_str = ""

list_url = f"https://{self.host}/vectors/list{prefix_str}"
params: dict = {}
if self.namespace:
params["namespace"] = self.namespace
metadata = []

while True:
if next_page_token:
params["paginationToken"] = next_page_token

async with self.async_client.get(
list_url, params=params, headers={"Api-Key": self.api_key}
) as response:
if response.status != 200:
error_text = await response.text()
logger.error(f"Error fetching vectors: {error_text}")
break

response_data = await response.json(content_type=None)

vector_ids = [vec["id"] for vec in response_data.get("vectors", [])]
if not vector_ids:
break
all_vector_ids.extend(vector_ids)

if include_metadata:
metadata_tasks = [self._async_fetch_metadata(id) for id in vector_ids]
metadata_results = await asyncio.gather(*metadata_tasks)
metadata.extend(metadata_results)

next_page_token = response_data.get("pagination", {}).get("next")
if not next_page_token:
break

return all_vector_ids, metadata

async def _async_fetch_metadata(self, vector_id: str) -> dict:
"""
Fetch metadata for a single vector ID asynchronously using the async_client.
"""
url = f"https://{self.host}/vectors/fetch"

params = {
"ids": [vector_id],
}

headers = {
"Api-Key": self.api_key,
}

async with self.async_client.get(
url, params=params, headers=headers
) as response:
if response.status != 200:
error_text = await response.text()
logger.error(f"Error fetching metadata: {error_text}")
return {}

try:
response_data = await response.json(content_type=None)
except Exception as e:
logger.warning(f"No metadata found for vector {vector_id}: {e}")
return {}

return (
response_data.get("vectors", {}).get(vector_id, {}).get("metadata", {})
)

async def _async_get_routes(self) -> list[tuple]:
"""
Gets a list of route and utterance objects currently stored in the index.
Returns:
List[Tuple]: A list of (route_name, utterance) tuples.
"""
_, metadata = await self._async_get_all(include_metadata=True)
route_tuples = [(x["sr_route"], x["sr_utterance"]) for x in metadata]
return route_tuples

def __len__(self):
return self.index.describe_index_stats()["total_vector_count"]
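As a rough usage sketch (not part of this commit), the new async path could be exercised along the following lines. The constructor arguments, the PINECONE_API_KEY environment variable, and the index name are assumptions about a typical PineconeIndex setup; aget_routes also expects the async client and host to already be initialized, as the ValueError above indicates.

# Hypothetical usage sketch; constructor arguments and environment
# variable names are assumed, adjust to your own PineconeIndex setup.
import asyncio
import os

from semantic_router.index.pinecone import PineconeIndex


async def main():
    index = PineconeIndex(
        api_key=os.environ["PINECONE_API_KEY"],  # assumed env var
        index_name="semantic-router--example",  # placeholder name
    )
    # aget_routes paginates /vectors/list, fetches per-vector metadata,
    # and returns (route_name, utterance) tuples.
    routes = await index.aget_routes()
    for route_name, utterance in routes:
        print(route_name, utterance)


if __name__ == "__main__":
    asyncio.run(main())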
4 changes: 4 additions & 0 deletions semantic_router/index/postgres.py
@@ -9,6 +9,7 @@

from semantic_router.index.base import BaseIndex
from semantic_router.schema import Metric
from semantic_router.utils.logger import logger


class MetricPgVecOperatorMap(Enum):
@@ -456,6 +457,9 @@ def delete_index(self) -> None:
cur.execute(f"DROP TABLE IF EXISTS {table_name}")
self.conn.commit()

def aget_routes(self):
logger.error("Sync remove is not implemented for PostgresIndex.")

def __len__(self):
"""
Returns the total number of vectors in the index.
3 changes: 3 additions & 0 deletions semantic_router/index/qdrant.py
@@ -317,6 +317,9 @@ async def aquery(
route_names = [result.payload[SR_ROUTE_PAYLOAD_KEY] for result in results]
return np.array(scores), route_names

def aget_routes(self):
logger.error("Sync remove is not implemented for QdrantIndex.")

def delete_index(self):
self.client.delete_collection(self.index_name)

11 changes: 5 additions & 6 deletions tests/unit/encoders/test_vit.py
@@ -7,7 +7,6 @@
from semantic_router.encoders import VitEncoder

test_model_name = "aurelio-ai/sr-test-vit"
vit_encoder = VitEncoder(name=test_model_name)
embed_dim = 32

if torch.cuda.is_available():
@@ -44,15 +43,11 @@ def test_vit_encoder__import_errors_torch(self, mocker):
with pytest.raises(ImportError):
VitEncoder()

def test_vit_encoder__import_errors_torchvision(self, mocker):
mocker.patch.dict("sys.modules", {"torchvision": None})
with pytest.raises(ImportError):
VitEncoder()

@pytest.mark.skipif(
os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"
)
def test_vit_encoder_initialization(self):
vit_encoder = VitEncoder(name=test_model_name)
assert vit_encoder.name == test_model_name
assert vit_encoder.type == "huggingface"
assert vit_encoder.score_threshold == 0.5
@@ -62,6 +57,7 @@ def test_vit_encoder_initialization(self):
os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"
)
def test_vit_encoder_call(self, dummy_pil_image):
vit_encoder = VitEncoder(name=test_model_name)
encoded_images = vit_encoder([dummy_pil_image] * 3)

assert len(encoded_images) == 3
@@ -71,6 +67,7 @@ def test_vit_encoder_call(self, dummy_pil_image):
os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"
)
def test_vit_encoder_call_misshaped(self, dummy_pil_image, misshaped_pil_image):
vit_encoder = VitEncoder(name=test_model_name)
encoded_images = vit_encoder([dummy_pil_image, misshaped_pil_image])

assert len(encoded_images) == 2
@@ -80,6 +77,7 @@ def test_vit_encoder_call_misshaped(self, dummy_pil_image, misshaped_pil_image):
os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"
)
def test_vit_encoder_process_images_device(self, dummy_pil_image):
vit_encoder = VitEncoder(name=test_model_name)
imgs = vit_encoder._process_images([dummy_pil_image] * 3)["pixel_values"]

assert imgs.device.type == device
@@ -88,6 +86,7 @@ def test_vit_encoder_process_images_device(self, dummy_pil_image):
os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"
)
def test_vit_encoder_ensure_rgb(self, dummy_black_and_white_img):
vit_encoder = VitEncoder(name=test_model_name)
rgb_image = vit_encoder._ensure_rgb(dummy_black_and_white_img)

assert rgb_image.mode == "RGB"
