Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add fastembed encoder #68

Merged
merged 2 commits into from
Jan 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
408 changes: 406 additions & 2 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ readme = "README.md"
packages = [{include = "semantic_router"}]

[tool.poetry.dependencies]
python = "^3.9"
python = ">=3.9,<3.12"
pydantic = "^1.8.2"
openai = "^1.3.9"
cohere = "^4.32"
Expand All @@ -22,6 +22,7 @@ pinecone-text = "^0.7.0"
colorlog = "^6.8.0"
pyyaml = "^6.0.1"
pytest-asyncio = "^0.23.2"
fastembed = "^0.1.3"


[tool.poetry.group.dev.dependencies]
Expand Down
9 changes: 8 additions & 1 deletion semantic_router/encoders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,12 @@
from semantic_router.encoders.bm25 import BM25Encoder
from semantic_router.encoders.cohere import CohereEncoder
from semantic_router.encoders.openai import OpenAIEncoder
from semantic_router.encoders.fastembed import FastEmbedEncoder

__all__ = ["BaseEncoder", "CohereEncoder", "OpenAIEncoder", "BM25Encoder"]
__all__ = [
"BaseEncoder",
"CohereEncoder",
"OpenAIEncoder",
"BM25Encoder",
"FastEmbedEncoder",
]
45 changes: 45 additions & 0 deletions semantic_router/encoders/fastembed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from typing import Any, List, Optional
import numpy as np
from pydantic import BaseModel, PrivateAttr


class FastEmbedEncoder(BaseModel):
type: str = "fastembed"
model_name: str = "BAAI/bge-small-en-v1.5"
max_length: int = 512
cache_dir: Optional[str] = None
threads: Optional[int] = None
_client: Any = PrivateAttr()

def __init__(self, **data):
super().__init__(**data)
self._client = self._initialize_client()

def _initialize_client(self):
try:
from fastembed.embedding import FlagEmbedding as Embedding
except ImportError:
raise ImportError(

Check warning on line 22 in semantic_router/encoders/fastembed.py

View check run for this annotation

Codecov / codecov/patch

semantic_router/encoders/fastembed.py#L21-L22

Added lines #L21 - L22 were not covered by tests
"Please install fastembed to use FastEmbedEncoder"
"You can install it with: `pip install fastembed`"
)

embedding_args = {
"model_name": self.model_name,
"max_length": self.max_length,
"cache_dir": self.cache_dir,
"threads": self.threads,
}

embedding_args = {k: v for k, v in embedding_args.items() if v is not None}

embedding = Embedding(**embedding_args)
return embedding

def __call__(self, docs: list[str]) -> list[list[float]]:
try:
embeds: List[np.ndarray] = list(self._client.embed(docs))
embeddings: List[List[float]] = [e.tolist() for e in embeds]
return embeddings
except Exception as e:
raise ValueError(f"FastEmbed embed failed. Error: {e}")

Check warning on line 45 in semantic_router/encoders/fastembed.py

View check run for this annotation

Codecov / codecov/patch

semantic_router/encoders/fastembed.py#L44-L45

Added lines #L44 - L45 were not covered by tests
Empty file added test_output.txt
Empty file.
10 changes: 10 additions & 0 deletions tests/unit/encoders/test_fastembed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from semantic_router.encoders import FastEmbedEncoder


class TestFastEmbedEncoder:
def test_fastembed_encoder(self):
encode = FastEmbedEncoder()
test_docs = ["This is a test", "This is another test"]

embeddings = encode(test_docs)
assert isinstance(embeddings, list)
Loading