Skip to content

Commit c92aeb1

Browse files
committed
fix: Fix get_embedding in the augmented classes
1 parent 8866bb5 commit c92aeb1

File tree

7 files changed

+686
-386
lines changed

7 files changed

+686
-386
lines changed

poetry.lock

Lines changed: 627 additions & 368 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ url = "https://download.pytorch.org/whl/cpu"
5252
priority = "explicit"
5353

5454
[tool.poetry.dependencies]
55+
python = ">=3.9,<4"
5556
torch = [
5657
{version = ">=2.5.0", markers="extra!='gpu'", source="pytorch-cpu"},
5758
{version = ">=2.5.0", markers="extra=='gpu' and extra!='cpu'"},
@@ -88,7 +89,7 @@ makim = "1.19.0"
8889
virtualenv = "<=20.25.1"
8990
python-dotenv = ">=1.0"
9091
# note: Version 3.7.1 requries spaCy >=3.7.2,<3.8.0
91-
en-core-web-md = {url = "https://github.com/explosion/spacy-models/releases/download/en_core_web_md-3.7.1/en_core_web_md-3.7.1-py3-none-any.whl"}
92+
en-core-web-md = {url = "https://github.com/explosion/spacy-models/releases/download/en_core_web_md-3.8.0/en_core_web_md-3.8.0-py3-none-any.whl"}
9293

9394
[tool.pytest.ini_options]
9495
testpaths = [

src/rago/augmented/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
EmbeddingType: TypeAlias = Union[
2020
npt.NDArray[np.float64],
21+
npt.NDArray[np.float32],
2122
Tensor,
2223
list[Tensor],
2324
]

src/rago/augmented/openai.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,9 @@ def get_embedding(self, content: list[str]) -> EmbeddingType:
3838
response = model.embeddings.create(
3939
input=content, model=self.model_name
4040
)
41-
result = np.array(response.data[0].embedding)
42-
result = result.reshape(1, result.size)
41+
result = np.array(
42+
[data.embedding for data in response.data], dtype=np.float32
43+
)
4344

4445
self._save_cache(cache_key, result)
4546

src/rago/augmented/spacy.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,22 @@ def get_embedding(self, content: List[str]) -> EmbeddingType:
3333

3434
model = cast(spacy.language.Language, self.model)
3535
embeddings = []
36+
3637
for text in content:
3738
doc = model(text)
39+
40+
# Ensure the model has proper vectors
41+
if not doc.has_vector:
42+
raise ValueError(f"Text: '{text}' has no valid word vectors!")
43+
3844
embeddings.append(doc.vector)
39-
result = np.array(embeddings)
45+
46+
result = np.array(embeddings, dtype=np.float32)
47+
48+
# Ensure 2D shape (num_texts, embedding_dim)
49+
if result.ndim == 1:
50+
result = result.reshape(1, -1)
51+
4052
self._save_cache(cache_key, result)
4153
return result
4254

tests/test_openai.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,26 @@ def api_key(env) -> str:
2626

2727

2828
@pytest.mark.skip_on_ci
29-
def test_aug_openai(animals_data: list[str], api_key: str) -> None:
29+
@pytest.mark.parametrize(
30+
'question,expected_answer',
31+
[
32+
('Is there any animal larger than a dinosaur?', 'Blue Whale'),
33+
(
34+
'What animal is renowned as the fastest animal on the planet?',
35+
'Peregrine Falcon',
36+
),
37+
('An animal which do pollination?', 'Honey Bee'),
38+
],
39+
)
40+
def test_aug_openai(
41+
animals_data: list[str], api_key: str, question: str, expected_answer: str
42+
) -> None:
3043
"""Test RAG pipeline with OpenAI's GPT."""
3144
logs = {
3245
'augmented': {},
3346
}
3447

35-
query = 'Is there any animal larger than a dinosaur?'
36-
top_k = 3
48+
top_k = 2
3749

3850
ret_string = StringRet(animals_data)
3951
aug_openai = OpenAIAug(
@@ -43,14 +55,14 @@ def test_aug_openai(animals_data: list[str], api_key: str) -> None:
4355
)
4456

4557
ret_result = ret_string.get()
46-
aug_result = aug_openai.search(query, ret_result)
58+
aug_result = aug_openai.search(question, ret_result)
4759

4860
assert aug_openai.top_k == top_k
4961
# note: openai as augmented doesn't work as expected
5062
# it is returning a very poor result
5163
# it needs to be revisited and improved
5264
assert len(aug_result) >= 1
53-
assert 'blue whale' in aug_result[0].lower()
65+
assert expected_answer.lower() in aug_result[0].lower()
5466

5567
# check if logs have been used
5668
assert logs['augmented']

tests/test_spacy.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,44 @@
11
"""Tests for Rago package using SpaCy."""
22

3+
import pytest
4+
35
from rago.augmented import SpaCyAug
4-
from rago.retrieval import StringRet
56

67

7-
def test_aug_spacy(animals_data: list[str]) -> None:
8+
@pytest.mark.parametrize(
9+
'question,expected_answer',
10+
[
11+
('Is there any animal larger than a dinosaur?', 'Blue Whale'),
12+
(
13+
'What animal is renowned as the fastest animal on the planet?',
14+
'Peregrine Falcon',
15+
),
16+
('An animal which do pollination?', 'Honey Bee'),
17+
],
18+
)
19+
def test_aug_spacy(
20+
animals_data: list[str], question: str, expected_answer: str
21+
) -> None:
822
"""Test RAG pipeline with SpaCy."""
923
logs = {
1024
'augmented': {},
1125
}
1226

13-
query = 'Is there any animal larger than a dinosaur?'
14-
top_k = 3
27+
top_k = 2
1528

16-
ret_string = StringRet(animals_data)
1729
aug_openai = SpaCyAug(
30+
model_name='en_core_web_md',
1831
top_k=top_k,
1932
logs=logs['augmented'],
2033
)
2134

22-
ret_result = ret_string.get()
23-
aug_result = aug_openai.search(query, ret_result)
35+
aug_result = aug_openai.search(question, animals_data)
2436

2537
assert aug_openai.top_k == top_k
26-
assert len(aug_result) == top_k
27-
assert any(['blue whale' in result.lower() for result in aug_result])
38+
assert top_k >= len(aug_result)
39+
assert any(
40+
[expected_answer.lower() in result.lower() for result in aug_result]
41+
)
2842

2943
# check if logs have been used
3044
assert logs['augmented']

0 commit comments

Comments
 (0)