Skip to content

Commit

Permalink
fastembed upd & test added
Browse files Browse the repository at this point in the history
  • Loading branch information
kdcokenny committed Jan 4, 2024
1 parent be312b0 commit f2f6e5f
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 15 deletions.
9 changes: 8 additions & 1 deletion semantic_router/encoders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,12 @@
from semantic_router.encoders.bm25 import BM25Encoder
from semantic_router.encoders.cohere import CohereEncoder
from semantic_router.encoders.openai import OpenAIEncoder
from semantic_router.encoders.fastembed import FastEmbedEncoder

__all__ = ["BaseEncoder", "CohereEncoder", "OpenAIEncoder", "BM25Encoder"]
__all__ = [
"BaseEncoder",
"CohereEncoder",
"OpenAIEncoder",
"BM25Encoder",
"FastEmbedEncoder",
]
31 changes: 17 additions & 14 deletions semantic_router/encoders/fastembed.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
from typing import List, Optional

from typing import Any, List, Optional
import numpy as np
from semantic_router.encoders.base import BaseEncoder
from pydantic import BaseModel, PrivateAttr


class FastEmbedEncoder(BaseEncoder):
class FastEmbedEncoder(BaseModel):
type: str = "fastembed"
model_name: str = "BAAI/bge-small-en-v1.5"
max_length: int = 512
cache_dir: Optional[str] = None
threads: Optional[int] = None
type: str = "fastembed"
_client: Any = PrivateAttr()

def __init__(self, **data):
super().__init__(**data)
self._client = self._initialize_client()

def init(self):
def _initialize_client(self):
try:
from fastembed.embedding import FlagEmbedding as Embedding
except ImportError:
Expand All @@ -23,20 +27,19 @@ def init(self):
embedding_args = {
"model_name": self.model_name,
"max_length": self.max_length,
"cache_dir": self.cache_dir,
"threads": self.threads,
}
if self.cache_dir is not None:
embedding_args["cache_dir"] = self.cache_dir
if self.threads is not None:
embedding_args["threads"] = self.threads

self.client = Embedding(**embedding_args)
embedding_args = {k: v for k, v in embedding_args.items() if v is not None}

embedding = Embedding(**embedding_args)
return embedding

def __call__(self, docs: list[str]) -> list[list[float]]:
try:
embeds: List[np.ndarray] = list(self.client.embed(docs))

embeds: List[np.ndarray] = list(self._client.embed(docs))
embeddings: List[List[float]] = [e.tolist() for e in embeds]

return embeddings
except Exception as e:
raise ValueError(f"FastEmbed embed failed. Error: {e}")

Check warning on line 45 in semantic_router/encoders/fastembed.py

View check run for this annotation

Codecov / codecov/patch

semantic_router/encoders/fastembed.py#L44-L45

Added lines #L44 - L45 were not covered by tests
Empty file added test_output.txt
Empty file.
10 changes: 10 additions & 0 deletions tests/unit/encoders/test_fastembed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from semantic_router.encoders import FastEmbedEncoder


class TestFastEmbedEncoder:
def test_fastembed_encoder(self):
encode = FastEmbedEncoder()
test_docs = ["This is a test", "This is another test"]

embeddings = encode(test_docs)
assert isinstance(embeddings, list)

0 comments on commit f2f6e5f

Please sign in to comment.