diff --git a/semantic_router/encoders/__init__.py b/semantic_router/encoders/__init__.py index ac27ebb4..4bb1eb37 100644 --- a/semantic_router/encoders/__init__.py +++ b/semantic_router/encoders/__init__.py @@ -2,5 +2,12 @@ from semantic_router.encoders.bm25 import BM25Encoder from semantic_router.encoders.cohere import CohereEncoder from semantic_router.encoders.openai import OpenAIEncoder +from semantic_router.encoders.fastembed import FastEmbedEncoder -__all__ = ["BaseEncoder", "CohereEncoder", "OpenAIEncoder", "BM25Encoder"] +__all__ = [ + "BaseEncoder", + "CohereEncoder", + "OpenAIEncoder", + "BM25Encoder", + "FastEmbedEncoder", +] diff --git a/semantic_router/encoders/fastembed.py b/semantic_router/encoders/fastembed.py index 6b700ade..d324058d 100644 --- a/semantic_router/encoders/fastembed.py +++ b/semantic_router/encoders/fastembed.py @@ -1,17 +1,21 @@ -from typing import List, Optional - +from typing import Any, List, Optional import numpy as np -from semantic_router.encoders.base import BaseEncoder +from pydantic import BaseModel, PrivateAttr -class FastEmbedEncoder(BaseEncoder): +class FastEmbedEncoder(BaseModel): + type: str = "fastembed" model_name: str = "BAAI/bge-small-en-v1.5" max_length: int = 512 cache_dir: Optional[str] = None threads: Optional[int] = None - type: str = "fastembed" + _client: Any = PrivateAttr() + + def __init__(self, **data): + super().__init__(**data) + self._client = self._initialize_client() - def init(self): + def _initialize_client(self): try: from fastembed.embedding import FlagEmbedding as Embedding except ImportError: @@ -23,20 +27,19 @@ def init(self): embedding_args = { "model_name": self.model_name, "max_length": self.max_length, + "cache_dir": self.cache_dir, + "threads": self.threads, } - if self.cache_dir is not None: - embedding_args["cache_dir"] = self.cache_dir - if self.threads is not None: - embedding_args["threads"] = self.threads - self.client = Embedding(**embedding_args) + embedding_args = {k: v for k, v in embedding_args.items() if v is not None} + + embedding = Embedding(**embedding_args) + return embedding def __call__(self, docs: list[str]) -> list[list[float]]: try: - embeds: List[np.ndarray] = list(self.client.embed(docs)) - + embeds: List[np.ndarray] = list(self._client.embed(docs)) embeddings: List[List[float]] = [e.tolist() for e in embeds] - return embeddings except Exception as e: raise ValueError(f"FastEmbed embed failed. Error: {e}") diff --git a/test_output.txt b/test_output.txt new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/encoders/test_fastembed.py b/tests/unit/encoders/test_fastembed.py new file mode 100644 index 00000000..5efdbfbe --- /dev/null +++ b/tests/unit/encoders/test_fastembed.py @@ -0,0 +1,10 @@ +from semantic_router.encoders import FastEmbedEncoder + + +class TestFastEmbedEncoder: + def test_fastembed_encoder(self): + encode = FastEmbedEncoder() + test_docs = ["This is a test", "This is another test"] + + embeddings = encode(test_docs) + assert isinstance(embeddings, list)