Skip to content

Commit 7e63ac9

Browse files
authored
Merge branch 'main' into embeddings
2 parents 59f22d8 + e75354d commit 7e63ac9

File tree

7 files changed

+161
-28
lines changed

7 files changed

+161
-28
lines changed

libs/kotaemon/kotaemon/embeddings/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from .base import BaseEmbeddings
22
from .endpoint_based import EndpointEmbeddings
3+
from .fastembed import FastEmbedEmbeddings
34
from .langchain_based import (
45
LCAzureOpenAIEmbeddings,
56
LCCohereEmbdeddings,
@@ -17,4 +18,5 @@
1718
"LCHuggingFaceEmbeddings",
1819
"OpenAIEmbeddings",
1920
"AzureOpenAIEmbeddings",
21+
"FastEmbedEmbeddings",
2022
]

libs/kotaemon/kotaemon/embeddings/base.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,12 @@ async def ainvoke(
1818
self, text: str | list[str] | Document | list[Document], *args, **kwargs
1919
) -> list[DocumentWithEmbedding]:
2020
raise NotImplementedError
21+
22+
def prepare_input(
    self, text: str | list[str] | Document | list[Document]
) -> list[Document]:
    """Normalize the accepted input forms into a list of documents.

    Args:
        text: a single string, a single ``Document``, or a list of either.

    Returns:
        A one-element list when ``text`` is a string or ``Document``; one
        ``Document(content=item)`` per item when ``text`` is a list. Any
        other value is handed back unchanged.
    """
    # NOTE(review): Document instances are passed through
    # ``Document(content=...)`` again rather than reused as-is; presumably
    # the constructor accepts another Document -- confirm.
    is_single = isinstance(text, (str, Document))
    if is_single:
        return [Document(content=text)]
    if not isinstance(text, list):
        # Fallback for inputs outside the annotated types.
        return text
    return [Document(content=item) for item in text]
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from typing import TYPE_CHECKING, Optional
2+
3+
from kotaemon.base import Document, DocumentWithEmbedding, Param
4+
5+
from .base import BaseEmbeddings
6+
7+
if TYPE_CHECKING:
8+
from fastembed import TextEmbedding
9+
10+
11+
class FastEmbedEmbeddings(BaseEmbeddings):
    """Utilize fastembed library for embeddings locally without GPU.

    Supported model: https://qdrant.github.io/fastembed/examples/Supported_Models/
    Code: https://github.com/qdrant/fastembed

    The ``fastembed`` package is imported lazily, the first time the client
    is built, so this module can be imported without it.
    """

    model_name: str = Param(
        "BAAI/bge-small-en-v1.5",
        help=(
            "Model name for fastembed. "
            "Supported model: "
            "https://qdrant.github.io/fastembed/examples/Supported_Models/"
        ),
    )
    batch_size: int = Param(
        256,
        help="Batch size for embeddings. Higher values use more memory, but are faster",
    )
    parallel: Optional[int] = Param(
        None,
        help=(
            "Number of threads to use for embeddings. "
            "If > 1, data-parallel encoding will be used. "
            "If 0, use all available CPUs. "
            "If None, use default onnxruntime threading. "
            "Defaults to None"
        ),
    )

    @Param.auto()
    def client_(self) -> "TextEmbedding":
        """Build the fastembed ``TextEmbedding`` client for ``model_name``.

        Raises:
            ImportError: if the ``fastembed`` package cannot be imported.
        """
        try:
            from fastembed import TextEmbedding
        except ImportError as err:
            # Re-raise with an actionable hint instead of a bare
            # "No module named 'fastembed'".
            raise ImportError(
                "FastEmbedEmbeddings requires the `fastembed` package. "
                "Please install it: `pip install fastembed`"
            ) from err

        return TextEmbedding(model_name=self.model_name)

    def invoke(
        self, text: str | list[str] | Document | list[Document], *args, **kwargs
    ) -> list[DocumentWithEmbedding]:
        """Embed the given text(s) or document(s).

        Args:
            text: a string, a ``Document``, or a list of either. Extra
                positional and keyword arguments are accepted but not used.

        Returns:
            One ``DocumentWithEmbedding`` per prepared input document, in
            input order, each carrying its embedding as a list of floats.
        """
        documents = self.prepare_input(text)
        vectors = self.client_.embed(
            [doc.content for doc in documents],
            batch_size=self.batch_size,
            parallel=self.parallel,
        )
        return [
            DocumentWithEmbedding(
                content=doc,
                # Convert every component to a built-in float: iterating the
                # backend's vector directly may yield array scalar objects
                # rather than Python floats.
                embedding=[float(value) for value in vector],
            )
            for doc, vector in zip(documents, vectors)
        ]

    async def ainvoke(
        self, text: str | list[str] | Document | list[Document], *args, **kwargs
    ) -> list[DocumentWithEmbedding]:
        """Fastembed does not support async API.

        Delegates to the synchronous ``invoke``, so the embedding work runs
        inline in the calling coroutine.
        """
        return self.invoke(text, *args, **kwargs)

libs/kotaemon/kotaemon/embeddings/openai.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,15 +41,6 @@ def max_retries_(self):
4141
return DEFAULT_MAX_RETRIES
4242
return self.max_retries
4343

44-
def prepare_input(
45-
self, text: str | list[str] | Document | list[Document]
46-
) -> list[Document]:
47-
if isinstance(text, (str, Document)):
48-
return [Document(content=text)]
49-
elif isinstance(text, list):
50-
return [Document(content=_) for _ in text]
51-
return text
52-
5344
def prepare_client(self, async_version: bool = False):
5445
"""Get the OpenAI client
5546

libs/kotaemon/pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ packages.find.exclude = ["tests*", "env*"]
1111
# metadata and dependencies
1212
[project]
1313
name = "kotaemon"
14-
version = "0.3.9"
14+
version = "0.3.10"
1515
requires-python = ">= 3.10"
1616
description = "Kotaemon core library for AI development."
1717
dependencies = [
@@ -61,6 +61,7 @@ adv = [
6161
"elasticsearch",
6262
"llama-cpp-python",
6363
"pdfservices-sdk @ git+https://github.com/niallcm/pdfservices-python-sdk.git@bump-and-unfreeze-requirements",
64+
"fastembed",
6465
]
6566
dev = [
6667
"ipython",

libs/kotaemon/tests/test_embedding_models.py

Lines changed: 78 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from kotaemon.base import Document
88
from kotaemon.embeddings import (
99
AzureOpenAIEmbeddings,
10+
FastEmbedEmbeddings,
1011
LCAzureOpenAIEmbeddings,
1112
LCCohereEmbdeddings,
1213
LCHuggingFaceEmbeddings,
@@ -20,6 +21,13 @@
2021
openai_embedding = CreateEmbeddingResponse.model_validate(json.load(f))
2122

2223

24+
def assert_embedding_result(output):
    """Assert that ``output`` has the shape of an embedding result.

    Only the first item is inspected: it must be a ``Document`` whose
    ``embedding`` is a list starting with a ``float``.
    """
    assert isinstance(output, list)
    first = output[0]
    assert isinstance(first, Document)
    vector = first.embedding
    assert isinstance(vector, list)
    assert isinstance(vector[0], float)
29+
30+
2331
@patch(
2432
"openai.resources.embeddings.Embeddings.create",
2533
side_effect=lambda *args, **kwargs: openai_embedding,
@@ -32,10 +40,7 @@ def test_lcazureopenai_embeddings_raw(openai_embedding_call):
3240
openai_api_key="some-key",
3341
)
3442
output = model("Hello world")
35-
assert isinstance(output, list)
36-
assert isinstance(output[0], Document)
37-
assert isinstance(output[0].embedding, list)
38-
assert isinstance(output[0].embedding[0], float)
43+
assert_embedding_result(output)
3944
openai_embedding_call.assert_called()
4045

4146

@@ -51,10 +56,67 @@ def test_lcazureopenai_embeddings_batch_raw(openai_embedding_call):
5156
openai_api_key="some-key",
5257
)
5358
output = model(["Hello world", "Goodbye world"])
54-
assert isinstance(output, list)
55-
assert isinstance(output[0], Document)
56-
assert isinstance(output[0].embedding, list)
57-
assert isinstance(output[0].embedding[0], float)
59+
assert_embedding_result(output)
60+
openai_embedding_call.assert_called()
61+
62+
63+
@patch(
    "openai.resources.embeddings.Embeddings.create",
    side_effect=lambda *args, **kwargs: openai_embedding,
)
def test_azureopenai_embeddings_raw(openai_embedding_call):
    """AzureOpenAIEmbeddings embeds one string through the patched create call."""
    settings = {
        "azure_endpoint": "https://test.openai.azure.com/",
        "api_key": "some-key",
        "api_version": "version",
        "azure_deployment": "text-embedding-ada-002",
    }
    embedder = AzureOpenAIEmbeddings(**settings)
    assert_embedding_result(embedder("Hello world"))
    openai_embedding_call.assert_called()
77+
78+
79+
@patch(
    "openai.resources.embeddings.Embeddings.create",
    side_effect=lambda *args, **kwargs: openai_embedding_batch,
)
def test_azureopenai_embeddings_batch_raw(openai_embedding_call):
    """AzureOpenAIEmbeddings embeds a list of strings through the patched create call."""
    settings = {
        "azure_endpoint": "https://test.openai.azure.com/",
        "api_key": "some-key",
        "api_version": "version",
        "azure_deployment": "text-embedding-ada-002",
    }
    embedder = AzureOpenAIEmbeddings(**settings)
    texts = ["Hello world", "Goodbye world"]
    assert_embedding_result(embedder(texts))
    openai_embedding_call.assert_called()
93+
94+
95+
@patch(
    "openai.resources.embeddings.Embeddings.create",
    side_effect=lambda *args, **kwargs: openai_embedding,
)
def test_openai_embeddings_raw(openai_embedding_call):
    """OpenAIEmbeddings embeds one string through the patched create call."""
    settings = {
        "api_key": "some-key",
        "model": "text-embedding-ada-002",
    }
    embedder = OpenAIEmbeddings(**settings)
    assert_embedding_result(embedder("Hello world"))
    openai_embedding_call.assert_called()
107+
108+
109+
@patch(
    "openai.resources.embeddings.Embeddings.create",
    side_effect=lambda *args, **kwargs: openai_embedding_batch,
)
def test_openai_embeddings_batch_raw(openai_embedding_call):
    """OpenAIEmbeddings embeds a list of strings through the patched create call."""
    settings = {
        "api_key": "some-key",
        "model": "text-embedding-ada-002",
    }
    embedder = OpenAIEmbeddings(**settings)
    texts = ["Hello world", "Goodbye world"]
    assert_embedding_result(embedder(texts))
    openai_embedding_call.assert_called()
59121

60122

@@ -148,10 +210,7 @@ def test_lchuggingface_embeddings(
148210
)
149211

150212
output = model("Hello World")
151-
assert isinstance(output, list)
152-
assert isinstance(output[0], Document)
153-
assert isinstance(output[0].embedding, list)
154-
assert isinstance(output[0].embedding[0], float)
213+
assert_embedding_result(output)
155214
sentence_transformers_init.assert_called()
156215
langchain_huggingface_embedding_call.assert_called()
157216

@@ -166,8 +225,11 @@ def test_lccohere_embeddings(langchain_cohere_embedding_call):
166225
)
167226

168227
output = model("Hello World")
169-
assert isinstance(output, list)
170-
assert isinstance(output[0], Document)
171-
assert isinstance(output[0].embedding, list)
172-
assert isinstance(output[0].embedding[0], float)
228+
assert_embedding_result(output)
173229
langchain_cohere_embedding_call.assert_called()
230+
231+
232+
def test_fastembed_embeddings():
    """FastEmbedEmbeddings with default parameters embeds a single string.

    NOTE(review): nothing is patched here, unlike the OpenAI tests in this
    file, so this presumably needs the fastembed package and its default
    model to be available wherever the suite runs -- confirm for CI.
    """
    embedder = FastEmbedEmbeddings()
    assert_embedding_result(embedder("Hello World"))

libs/ktem/flowsettings.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757
if config("AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT", default=""):
5858
KH_EMBEDDINGS["azure"] = {
5959
"spec": {
60-
"__type__": "kotaemon.embeddings.AzureOpenAIEmbeddings",
60+
"__type__": "kotaemon.embeddings.LCAzureOpenAIEmbeddings",
6161
"azure_endpoint": config("AZURE_OPENAI_ENDPOINT", default=""),
6262
"api_key": config("AZURE_OPENAI_API_KEY", default=""),
6363
"api_version": config("OPENAI_API_VERSION", default="")
@@ -88,7 +88,7 @@
8888
if len(KH_EMBEDDINGS) < 1:
8989
KH_EMBEDDINGS["openai"] = {
9090
"spec": {
91-
"__type__": "kotaemon.embeddings.OpenAIEmbeddings",
91+
"__type__": "kotaemon.embeddings.LCOpenAIEmbeddings",
9292
"base_url": config("OPENAI_API_BASE", default="")
9393
or "https://api.openai.com/v1",
9494
"api_key": config("OPENAI_API_KEY", default=""),

0 commit comments

Comments
 (0)