Skip to content

Commit c134033

Browse files
committed
fix gemini
1 parent 9731c83 commit c134033

File tree

7 files changed

+179
-15
lines changed

7 files changed

+179
-15
lines changed

src/rago/core.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from typing import Any
66

7+
from pydantic import BaseModel
78
from typeguard import typechecked
89

910
from rago.augmented.base import AugmentedBase
@@ -50,7 +51,7 @@ def __init__(
5051
'generation': generation.logs,
5152
}
5253

53-
def prompt(self, query: str, device: str = 'auto') -> str:
54+
def prompt(self, query: str, device: str = 'auto') -> str | BaseModel:
5455
"""Run the pipeline for a specific prompt.
5556
5657
Parameters
@@ -72,7 +73,7 @@ def prompt(self, query: str, device: str = 'auto') -> str:
7273
aug_data = self.augmented.search(query, ret_data)
7374
self.logs['augmented']['result'] = aug_data
7475

75-
gen_data: str = self.generation.generate(query, context=aug_data)
76+
gen_data = self.generation.generate(query, context=aug_data)
7677
self.logs['generation']['result'] = gen_data
7778

7879
return gen_data

src/rago/generation/base.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from __future__ import annotations
44

55
from abc import abstractmethod
6-
from typing import Any, Optional
6+
from typing import Any, Optional, Type
77

88
import torch
99

@@ -27,7 +27,7 @@ class GenerationBase:
2727
prompt_template: str = (
2828
'question: \n```\n{query}\n```\ncontext: ```\n{context}\n```'
2929
)
30-
structured_output: Optional[BaseModel] = None
30+
structured_output: Optional[Type[BaseModel]] = None
3131

3232
# default parameters that can be overwritten by the derived class
3333
default_device_name: str = 'cpu'
@@ -46,7 +46,7 @@ def __init__(
4646
prompt_template: str = '',
4747
output_max_length: int = 500,
4848
device: str = 'auto',
49-
structured_output: Optional[BaseModel] = None,
49+
structured_output: Optional[Type[BaseModel]] = None,
5050
logs: dict[str, Any] = {},
5151
) -> None:
5252
"""Initialize Generation class.
@@ -61,7 +61,7 @@ def __init__(
6161
output_max_length : int
6262
Maximum length of the generated output.
6363
device: str (default=auto)
64-
structured_output: Optional[BaseModel] = None
64+
structured_output: Optional[Type[BaseModel]] = None
6565
logs: dict[str, Any] = {}
6666
"""
6767
self.api_key: str = api_key
@@ -74,7 +74,7 @@ def __init__(
7474
self.prompt_template: str = (
7575
prompt_template or self.default_prompt_template
7676
)
77-
self.structured_output: Optional[BaseModel] = None
77+
self.structured_output: Optional[Type[BaseModel]] = structured_output
7878

7979
if device not in ['cpu', 'cuda', 'auto']:
8080
raise Exception(
@@ -105,7 +105,7 @@ def generate(
105105
self,
106106
query: str,
107107
context: list[str],
108-
) -> str:
108+
) -> str | BaseModel:
109109
"""Generate text with optional language parameter.
110110
111111
Parameters

src/rago/generation/gemini.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import google.generativeai as genai
88
import instructor
99

10+
from pydantic import BaseModel
1011
from typeguard import typechecked
1112

1213
from rago.generation.base import GenerationBase
@@ -24,20 +25,38 @@ def _setup(self) -> None:
2425
model = genai.GenerativeModel(self.model_name)
2526

2627
self.model = (
27-
instructor.from_gemini(model) if self.structured_output else model
28+
instructor.from_gemini(
29+
client=model,
30+
mode=instructor.Mode.GEMINI_JSON,
31+
)
32+
if self.structured_output
33+
else model
2834
)
2935

30-
def generate(self, query: str, context: list[str]) -> str:
36+
def generate(self, query: str, context: list[str]) -> str | BaseModel:
3137
"""Generate text using Gemini model support."""
3238
input_text = self.prompt_template.format(
3339
query=query, context=' '.join(context)
3440
)
3541

42+
if not self.structured_output:
43+
models_params_gen = {'contents': input_text}
44+
response = self.model.generate_content(**models_params_gen)
45+
self.logs['model_params'] = models_params_gen
46+
return cast(str, response.text.strip())
47+
48+
messages = [
49+
{'role': 'user', 'content': input_text},
50+
]
3651
model_params = {
37-
'contents': input_text,
52+
'messages': messages,
53+
'response_model': self.structured_output,
3854
}
3955

40-
response = self.model.generate_content(**model_params)
56+
response = self.model.create(
57+
**model_params,
58+
)
4159

4260
self.logs['model_params'] = model_params
43-
return cast(str, response.text.strip())
61+
62+
return cast(BaseModel, response)

src/rago/generation/openai.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import instructor
88
import openai
99

10+
from pydantic import BaseModel
1011
from typeguard import typechecked
1112

1213
from rago.generation.base import GenerationBase
@@ -30,7 +31,7 @@ def generate(
3031
self,
3132
query: str,
3233
context: list[str],
33-
) -> str:
34+
) -> str | BaseModel:
3435
"""Generate text using OpenAI's API with dynamic model support."""
3536
input_text = self.prompt_template.format(
3637
query=query, context=' '.join(context)
@@ -49,8 +50,15 @@ def generate(
4950
presence_penalty=0.3,
5051
)
5152

53+
if self.structured_output:
54+
model_params['response_model'] = self.structured_output
55+
5256
response = self.model.chat.completions.create(**model_params)
5357

5458
self.logs['model_params'] = model_params
5559

56-
return cast(str, response.choices[0].message.content.strip())
60+
has_choices = hasattr(response, 'choices')
61+
62+
if has_choices and isinstance(response.choices, list):
63+
return cast(str, response.choices[0].message.content.strip())
64+
return cast(BaseModel, response)

tests/models.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
"""Models used for the unit tests."""
2+
3+
from __future__ import annotations
4+
5+
from typing import Literal
6+
7+
from pydantic import BaseModel, Field
8+
9+
10+
class AnimalModel(BaseModel):
11+
"""Model for animals."""
12+
13+
name: Literal[
14+
'Blue Whale',
15+
'Peregrine Falcon',
16+
'Giant Panda',
17+
'Cheetah',
18+
'Komodo Dragon',
19+
'Arctic Fox',
20+
'Monarch Butterfly',
21+
'Great White Shark',
22+
'Honey Bee',
23+
'Emperor Penguin',
24+
'Unknown',
25+
] = Field(
26+
...,
27+
description='The predicted class label.',
28+
)

tests/test_gemini.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,17 @@
22

33
import os
44

5+
from typing import cast
6+
57
import pytest
68

79
from rago import Rago
810
from rago.augmented import SentenceTransformerAug
911
from rago.generation import GeminiGen
1012
from rago.retrieval import StringRet
1113

14+
from .models import AnimalModel
15+
1216

1317
@pytest.fixture
1418
def api_key(env) -> str:
@@ -52,3 +56,53 @@ def test_gemini_generation(animals_data: list[str], api_key: str) -> None:
5256
assert logs['retrieval']
5357
assert logs['augmented']
5458
assert logs['generation']
59+
60+
@pytest.mark.skip_on_ci
@pytest.mark.parametrize(
    'question,expected_answer',
    [
        ('What animal is larger than a dinosaur?', 'Blue Whale'),
        (
            'What animal is renowned as the fastest animal on the planet?',
            'Peregrine Falcon',
        ),
    ],
)
def test_rag_gemini_structured_output(
    api_key: str,
    animals_data: list[str],
    question: str,
    expected_answer: str,
) -> None:
    """Test RAG pipeline with Gemini."""
    # One log bucket per pipeline stage so each stage can be asserted on.
    logs: dict[str, dict] = {
        stage: {} for stage in ('retrieval', 'augmented', 'generation')
    }

    generation = GeminiGen(
        api_key=api_key,
        model_name='gemini-1.5-flash',
        logs=logs['generation'],
        structured_output=AnimalModel,
    )
    rag = Rago(
        retrieval=StringRet(animals_data, logs=logs['retrieval']),
        augmented=SentenceTransformerAug(top_k=3, logs=logs['augmented']),
        generation=generation,
    )

    result = cast(AnimalModel, rag.prompt(question))

    assert result.name == expected_answer, (
        f'Expected response to mention `{expected_answer}`. '
        f'Result: `{result.name}`.'
    )

    # check if logs have been used
    assert logs['retrieval']
    assert logs['augmented']
    assert logs['generation']

tests/test_openai.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,17 @@
22

33
import os
44

5+
from typing import cast
6+
57
import pytest
68

79
from rago import Rago
810
from rago.augmented import OpenAIAug
911
from rago.generation import OpenAIGen
1012
from rago.retrieval import StringRet
1113

14+
from .models import AnimalModel
15+
1216

1317
@pytest.fixture
1418
def api_key(env) -> str:
@@ -82,3 +86,53 @@ def test_rag_openai_gpt(animals_data: list[str], api_key: str) -> None:
8286
assert logs['retrieval']
8387
assert logs['augmented']
8488
assert logs['generation']
89+
90+
@pytest.mark.skip_on_ci
@pytest.mark.parametrize(
    'question,expected_answer',
    [
        ('What animal is larger than a dinosaur?', 'Blue Whale'),
        (
            'What animal is renowned as the fastest animal on the planet?',
            'Peregrine Falcon',
        ),
    ],
)
def test_rag_openai_gpt_structured_output(
    api_key: str,
    animals_data: list[str],
    question: str,
    expected_answer: str,
) -> None:
    """Test RAG pipeline with OpenAI's GPT."""
    # One log bucket per pipeline stage so each stage can be asserted on.
    logs: dict[str, dict] = {
        stage: {} for stage in ('retrieval', 'augmented', 'generation')
    }

    generation = OpenAIGen(
        api_key=api_key,
        model_name='gpt-3.5-turbo',
        logs=logs['generation'],
        structured_output=AnimalModel,
    )
    rag = Rago(
        retrieval=StringRet(animals_data, logs=logs['retrieval']),
        augmented=OpenAIAug(api_key=api_key, top_k=3, logs=logs['augmented']),
        generation=generation,
    )

    result = cast(AnimalModel, rag.prompt(question))

    assert result.name == expected_answer, (
        f'Expected response to mention `{expected_answer}`. '
        f'Result: `{result.name}`.'
    )

    # check if logs have been used
    assert logs['retrieval']
    assert logs['augmented']
    assert logs['generation']

0 commit comments

Comments (0)