Skip to content

Commit

Permalink
feat: Add support for extra parameters for RAG generation (#32)
Browse files: browse the repository at this point in the history
  • Loading branch information
xmnlab authored Jan 21, 2025
1 parent acd77b1 commit e7a57c1
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 3 deletions.
10 changes: 10 additions & 0 deletions src/rago/generation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

from abc import abstractmethod
from copy import deepcopy
from typing import Any, Optional, Type

import torch
Expand All @@ -14,6 +15,7 @@
from rago.extensions.cache import Cache

DEFAULT_LOGS: dict[str, Any] = {}
DEFAULT_API_PARAMS: dict[str, Any] = {}


@typechecked
Expand All @@ -31,6 +33,7 @@ class GenerationBase(RagoBase):
'question: \n```\n{query}\n```\ncontext: ```\n{context}\n```'
)
structured_output: Optional[Type[BaseModel]] = None
api_params: dict[str, Any] = {} # noqa: RUF012

# default parameters that derived classes can override
default_device_name: str = 'cpu'
Expand All @@ -40,6 +43,7 @@ class GenerationBase(RagoBase):
default_prompt_template: str = (
'question: \n```\n{query}\n```\ncontext: ```\n{context}\n```'
)
default_api_params: dict[str, Any] = {} # noqa: RUF012

def __init__(
self,
Expand All @@ -49,6 +53,7 @@ def __init__(
output_max_length: int = 500,
device: str = 'auto',
structured_output: Optional[Type[BaseModel]] = None,
api_params: dict[str, Any] = DEFAULT_API_PARAMS,
api_key: str = '',
cache: Optional[Cache] = None,
logs: dict[str, Any] = DEFAULT_LOGS,
Expand All @@ -57,6 +62,7 @@ def __init__(
if logs is DEFAULT_LOGS:
logs = {}
super().__init__(api_key=api_key, cache=cache, logs=logs)

self.model_name: str = model_name or self.default_model_name
self.output_max_length: int = (
output_max_length or self.default_output_max_length
Expand All @@ -67,6 +73,10 @@ def __init__(
prompt_template or self.default_prompt_template
)
self.structured_output: Optional[Type[BaseModel]] = structured_output
if api_params is DEFAULT_API_PARAMS:
api_params = deepcopy(self.default_api_params or {})

self.api_params = api_params

if device not in ['cpu', 'cuda', 'auto']:
raise Exception(
Expand Down
13 changes: 10 additions & 3 deletions src/rago/generation/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ class OpenAIGen(GenerationBase):
"""OpenAI generation model for text generation."""

default_model_name = 'gpt-3.5-turbo'
default_api_params = { # noqa: RUF012
'top_p': 1.0,
'frequency_penalty': 0.0,
'presence_penalty': 0.0,
}

def _setup(self) -> None:
"""Set up the object with the initial parameters."""
Expand All @@ -40,14 +45,16 @@ def generate(
if not self.model:
raise Exception('The model was not created.')

api_params = (
self.api_params if self.api_params else self.default_api_params
)

model_params = dict(
model=self.model_name,
messages=[{'role': 'user', 'content': input_text}],
max_tokens=self.output_max_length,
temperature=self.temperature,
top_p=0.9,
frequency_penalty=0.5,
presence_penalty=0.3,
**api_params,
)

if self.structured_output:
Expand Down

0 comments on commit e7a57c1

Please sign in to comment.