-
Notifications
You must be signed in to change notification settings - Fork 53
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* chore: Poetry + precommit * chore: Poetry * feat: Encoders * chore: Clean notebook outputs * feat: Added encoders to query * chore: Added a note to fix delete for Pinecone * chore: Linting fix * fix: Fix Pinecone deletion by file_url * fix: Pinecone delete * Small tweaks * Fix linting --------- Co-authored-by: Ismail Pelaseyed <homanp@gmail.com>
- Loading branch information
Showing
21 changed files
with
710 additions
and
59 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,8 @@ | ||
API_BASE_URL=https://rag.superagent.sh | ||
OPENAI_API_KEY= | ||
COHERE_API_KEY= | ||
OPENAI_API_KEY= | ||
|
||
# Optional for walkthrough | ||
PINECONE_API_KEY= | ||
PINECONE_HOST= | ||
PINECONE_INDEX= |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
.venv | ||
.env | ||
__pycache__/ | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import os\n", | ||
"import requests\n", | ||
"from dotenv import load_dotenv\n", | ||
"load_dotenv()\n", | ||
"\n", | ||
"API_URL = os.environ.get('API_BASE_URL', 'http://localhost:8000')\n", | ||
"PINECONE_API_KEY = os.environ.get('PINECONE_API_KEY', '')\n", | ||
"PINECONE_INDEX = os.environ.get('PINECONE_INDEX', '')\n", | ||
"PINECONE_HOST = os.environ.get('PINECONE_HOST', '')\n", | ||
"\n", | ||
"print(\"API_URL:\", API_URL)\n", | ||
"print(\"PINECONE_API_KEY:\", PINECONE_API_KEY)\n", | ||
"print(\"PINECONE_INDEX:\", PINECONE_INDEX)\n", | ||
"print(\"PINECONE_HOST:\", PINECONE_HOST)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Ingest a file\n", | ||
"url = f\"{API_URL}/api/v1/ingest\"\n", | ||
"\n", | ||
"payload = {\n", | ||
" \"files\": [\n", | ||
" {\n", | ||
" \"type\": \"PDF\",\n", | ||
" \"url\": \"https://arxiv.org/pdf/2402.05131.pdf\"\n", | ||
" }\n", | ||
" ],\n", | ||
" \"vector_database\": {\n", | ||
" \"type\": \"pinecone\",\n", | ||
" \"config\": {\n", | ||
" \"api_key\": PINECONE_API_KEY,\n", | ||
" \"host\": PINECONE_HOST,\n", | ||
" }\n", | ||
" },\n", | ||
" \"index_name\": PINECONE_INDEX,\n", | ||
" \"encoder\": \"openai\",\n", | ||
"}\n", | ||
"\n", | ||
"response = requests.post(url, json=payload)\n", | ||
"\n", | ||
"print(response.json())" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Query the index\n", | ||
"query_url = f\"{API_URL}/api/v1/query\"\n", | ||
"\n", | ||
"query_payload = {\n", | ||
" \"input\": \"What is the best chunk strategy?\",\n", | ||
" \"vector_database\": {\n", | ||
" \"type\": \"pinecone\",\n", | ||
" \"config\": {\n", | ||
" \"api_key\": PINECONE_API_KEY,\n", | ||
" \"host\": PINECONE_HOST,\n", | ||
" }\n", | ||
" },\n", | ||
" \"index_name\": PINECONE_INDEX,\n", | ||
" \"encoder\": \"openai\",\n", | ||
"}\n", | ||
"\n", | ||
"query_response = requests.post(query_url, json=query_payload)\n", | ||
"\n", | ||
"print(query_response.json())\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Delete the ingested file's vectors from the index\n", | ||
"query_url = f\"{API_URL}/api/v1/delete\"\n", | ||
"\n", | ||
"delete_payload = {\n", | ||
" \"file_url\": \"https://arxiv.org/pdf/2402.05131.pdf\",\n", | ||
" \"vector_database\": {\n", | ||
" \"type\": \"pinecone\",\n", | ||
" \"config\": {\n", | ||
" \"api_key\": PINECONE_API_KEY,\n", | ||
" \"host\": PINECONE_HOST,\n", | ||
" }\n", | ||
" },\n", | ||
" \"index_name\": PINECONE_INDEX,\n", | ||
"}\n", | ||
"\n", | ||
"delete_response = requests.delete(query_url, json=delete_payload)\n", | ||
"\n", | ||
"print(delete_response.json())" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": ".venv", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.3" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
from encoders.base import BaseEncoder | ||
from encoders.bm25 import BM25Encoder | ||
from encoders.cohere import CohereEncoder | ||
from encoders.fastembed import FastEmbedEncoder | ||
from encoders.huggingface import HuggingFaceEncoder | ||
from encoders.openai import OpenAIEncoder | ||
|
||
__all__ = [ | ||
"BaseEncoder", | ||
"CohereEncoder", | ||
"OpenAIEncoder", | ||
"BM25Encoder", | ||
"FastEmbedEncoder", | ||
"HuggingFaceEncoder", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
from typing import List | ||
|
||
from pydantic.v1 import BaseModel, Field | ||
|
||
|
||
class BaseEncoder(BaseModel):
    """Base pydantic model for encoders.

    Holds the fields shared by encoder implementations and defines the
    call signature (documents in, one embedding per document out) that
    subclasses are expected to override.
    """

    # Model/encoder identifier; required.
    name: str
    # Score threshold associated with this encoder; required.
    score_threshold: float
    # Encoder kind tag; subclasses override the default.
    type: str = Field(default="base")
    # Embedding dimension; defaults to 1536.
    dimension: int = Field(default=1536)

    class Config:
        # Allow fields annotated with types pydantic cannot validate itself.
        arbitrary_types_allowed = True

    def __call__(self, docs: List[str]) -> List[List[float]]:
        """Encode ``docs`` into embeddings.

        The base implementation always raises ``NotImplementedError``;
        subclasses must provide the real encoding.
        """
        raise NotImplementedError("Subclasses must implement this method")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
from typing import Any, Dict, List, Optional | ||
|
||
from semantic_router.encoders import BaseEncoder | ||
from semantic_router.utils.logger import logger | ||
|
||
|
||
class BM25Encoder(BaseEncoder):
    """Sparse BM25 encoder wrapping ``pinecone_text.sparse.BM25Encoder``.

    The wrapped model produces sparse ``{"indices": ..., "values": ...}``
    outputs; ``__call__`` expands them into dense rows whose columns are
    given by ``idx_mapping``.
    """

    # Underlying pinecone-text BM25 model (set in ``__init__``).
    model: Optional[Any] = None
    # Maps a model term index to its column position in the returned rows.
    # Stays ``None`` until default params are loaded or ``fit`` is called.
    idx_mapping: Optional[Dict[int, int]] = None
    type: str = "sparse"

    def __init__(
        self,
        name: str = "bm25",
        score_threshold: float = 0.82,
        use_default_params: bool = True,
    ):
        """Create the encoder.

        Args:
            name: Encoder name passed to the base model.
            score_threshold: Score threshold passed to the base model.
            use_default_params: When true, replace the fresh model with
                ``BM25Encoder.default()`` and build ``idx_mapping`` from it.

        Raises:
            ImportError: If ``pinecone_text`` is not installed.
        """
        super().__init__(name=name, score_threshold=score_threshold)
        try:
            # Imported lazily so the package is only required when used.
            from pinecone_text.sparse import BM25Encoder as encoder
        except ImportError:
            raise ImportError(
                "Please install pinecone-text to use BM25Encoder. "
                "You can install it with: `pip install 'semantic-router[hybrid]'`"
            )

        self.model = encoder()

        if use_default_params:
            logger.info("Downloading and initializing default sBM25 model parameters.")
            self.model = encoder.default()
            self._set_idx_mapping()

    def _set_idx_mapping(self):
        """Rebuild ``idx_mapping`` from the model's ``doc_freq`` indices.

        Raises:
            TypeError: If the model's ``doc_freq`` parameter is not a dict.
        """
        params = self.model.get_params()
        doc_freq = params["doc_freq"]
        if isinstance(doc_freq, dict):
            indices = doc_freq["indices"]
            self.idx_mapping = {int(idx): i for i, idx in enumerate(indices)}
        else:
            raise TypeError("Expected a dictionary for 'doc_freq'")

    def __call__(self, docs: List[str]) -> List[List[float]]:
        """Encode ``docs`` into dense rows of BM25 weights.

        A single document is encoded with ``encode_queries``; several
        documents are encoded with ``encode_documents``.

        Args:
            docs: Non-empty list of texts.

        Returns:
            One row per document, each of length ``len(idx_mapping)``.

        Raises:
            ValueError: If the model or index mapping is missing, or if
                ``docs`` is empty.
        """
        if self.model is None or self.idx_mapping is None:
            raise ValueError("Model or index mapping is not initialized.")
        if len(docs) == 1:
            sparse_dicts = self.model.encode_queries(docs)
        elif len(docs) > 1:
            sparse_dicts = self.model.encode_documents(docs)
        else:
            raise ValueError("No documents to encode.")

        width = len(self.idx_mapping)
        # Build an independent row per document. `[[0.0] * width] * n` would
        # repeat one shared list, so every write would show up in every row.
        embeds = [[0.0] * width for _ in docs]
        for i, output in enumerate(sparse_dicts):
            row = embeds[i]
            for idx, val in zip(output["indices"], output["values"]):
                # Term indices absent from the mapping are skipped.
                if idx in self.idx_mapping:
                    row[self.idx_mapping[idx]] = val
        return embeds

    def fit(self, docs: List[str]):
        """Fit the underlying model on ``docs`` and refresh ``idx_mapping``.

        Raises:
            ValueError: If the model is not initialized.
        """
        if self.model is None:
            raise ValueError("Model is not initialized.")
        self.model.fit(docs)
        self._set_idx_mapping()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from typing import List, Optional | ||
|
||
import cohere | ||
from decouple import config | ||
|
||
from encoders import BaseEncoder | ||
|
||
|
||
class CohereEncoder(BaseEncoder):
    """Encoder that obtains embeddings from the Cohere API."""

    # Cohere SDK client, created in ``__init__``.
    client: Optional[cohere.Client] = None
    type: str = "cohere"

    def __init__(
        self,
        name: Optional[str] = None,
        cohere_api_key: Optional[str] = None,
        score_threshold: float = 0.3,
    ):
        """Create the encoder and its Cohere client.

        Args:
            name: Embedding model name. When omitted, read from the
                ``COHERE_MODEL_NAME`` setting, falling back to
                ``"embed-english-v3.0"``.
            cohere_api_key: API key. When omitted, read from the
                ``COHERE_API_KEY`` setting.
            score_threshold: Score threshold passed to the base model.

        Raises:
            ValueError: If no API key is available, or if the Cohere
                client cannot be constructed.
        """
        if name is None:
            name = config("COHERE_MODEL_NAME", "embed-english-v3.0")
        super().__init__(name=name, score_threshold=score_threshold)
        # Pass an explicit default: without it decouple raises
        # UndefinedValueError for a missing setting and the check below
        # would never run.
        cohere_api_key = cohere_api_key or config("COHERE_API_KEY", default=None)
        # Falsy check also rejects an empty `COHERE_API_KEY=` entry.
        if not cohere_api_key:
            raise ValueError("Cohere API key cannot be 'None'.")
        try:
            self.client = cohere.Client(cohere_api_key)
        except Exception as e:
            raise ValueError(
                f"Cohere API client failed to initialize. Error: {e}"
            ) from e

    def __call__(self, docs: List[str]) -> List[List[float]]:
        """Embed ``docs`` with the configured Cohere model.

        Args:
            docs: Texts to embed.

        Returns:
            The ``embeddings`` attribute of the Cohere embed response.

        Raises:
            ValueError: If the client is missing or the API call fails.
        """
        if self.client is None:
            raise ValueError("Cohere client is not initialized.")
        try:
            embeds = self.client.embed(docs, input_type="search_query", model=self.name)
            return embeds.embeddings
        except Exception as e:
            raise ValueError(f"Cohere API call failed. Error: {e}") from e
Oops, something went wrong.