Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add basic OpenAI model implementation #26

Open
wants to merge 5 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions llm_eval/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# 1) model들을 등록할 전역 레지스트리 (dict)
MODEL_REGISTRY: Dict[str, Type[BaseModel]] = {}


# 2) 레지스트리에 등록할 헬퍼 함수
def register_model(name: str):
"""
Expand All @@ -13,13 +14,14 @@ def register_model(name: str):
class VLLMModel(BaseModel):
...
"""

def decorator(cls: Type[ModelType]):
if name in MODEL_REGISTRY:
raise ValueError(f"Model '{name}' already registered.")
MODEL_REGISTRY[name] = cls
return cls
return decorator

return decorator


# 3) 레지스트리에서 model 인스턴스를 생성하는 함수
Expand All @@ -28,12 +30,16 @@ def load_model(name: str, **kwargs) -> BaseModel:
문자열 name을 받아 해당 모델 클래스를 찾아 인스턴스화 후 반환.
"""
if name not in MODEL_REGISTRY:
raise ValueError(f"Unknown model: {name}. Please register it in MODEL_REGISTRY.")
raise ValueError(
f"Unknown model: {name}. Please register it in MODEL_REGISTRY."
)
model_cls = MODEL_REGISTRY[name]
return model_cls(**kwargs)


# 5) 실제 backend들 import -> 데코레이터로 등록
# from .vllm_backend import VLLMModel
# from .huggingface_backend import HFModel
# from .openai_backend import OpenAIModel
from .openai_backend import OpenAIModel

# from .multi_model import MultiModel
120 changes: 120 additions & 0 deletions llm_eval/models/openai_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import time
from typing import Any, Dict, List, Optional, Union

import openai

from .base import BaseModel, register_model


@register_model("openai")
class OpenAIModel(BaseModel):
    """Backend that generates text through the OpenAI HTTP API.

    Model names starting with "gpt" are sent to the chat-completions
    endpoint; any other name is sent to the legacy completions endpoint.

    NOTE(review): "o1" / "o1-mini" do not start with "gpt" and are
    therefore routed to the legacy completions endpoint — confirm this
    routing is intended for those models.
    """

    def __init__(
        self,
        api_key: str,
        api_base: str = "https://api.openai.com/v1",
        model_name: str = "gpt-4o",  # e.g. gpt-4o-mini, o1, o1-mini, etc.
        system_message: Optional[str] = None,
        **kwargs,
    ):
        """
        Args:
            api_key: OpenAI API key; required, rejected when empty.
            api_base: Base URL of the (OpenAI-compatible) API server.
            model_name: Model identifier sent with every request.
            system_message: Optional system prompt prepended to chat inputs.
            **kwargs: Default sampling parameters (max_tokens, temperature,
                top_p, ...) merged into every request payload.

        Raises:
            ValueError: If api_key is falsy.
        """
        super().__init__()
        if not api_key:
            raise ValueError("API key is required")

        self._client = openai.Client(api_key=api_key, base_url=api_base)
        self.model_name = model_name
        self.system_message = system_message
        self.default_params = kwargs

    def _create_payload(
        self,
        inputs: Union[str, List[Dict]],
        return_logits: bool = False,
        **kwargs,
    ) -> Dict[str, Any]:
        """Build the request payload for either API flavor.

        Per-call kwargs override the instance's default_params.
        None-valued entries are dropped so the server applies its own
        defaults.

        Args:
            inputs: A plain prompt string, or a list of chat message dicts.
            return_logits: Request token logprobs from the API.

        Returns:
            Keyword arguments for the OpenAI client call.
        """
        params = self.default_params.copy()
        params.update(kwargs)

        payload: Dict[str, Any] = {"model": self.model_name}

        if not self.model_name.startswith("gpt"):
            # Legacy completions API: plain prompt; forward all params as-is.
            payload = {"model": self.model_name, "prompt": inputs, **params}
            if return_logits:
                payload["logprobs"] = 5  # top-5 token logprobs per position

        else:
            # Chat completions API: build the message list.
            messages = []
            if self.system_message:
                messages.append({"role": "system", "content": self.system_message})
            if isinstance(inputs, str):
                messages.append({"role": "user", "content": inputs})
            else:
                messages.extend(inputs)
            payload["messages"] = messages

            if return_logits:
                # The chat API only returns logprob data when explicitly
                # requested; without this the response carries none, and the
                # hasattr() check in generate_batch would never find any.
                payload["logprobs"] = True

            # Only forward parameters the chat endpoint understands.
            for param in [
                "max_tokens",
                "temperature",
                "top_p",
                "frequency_penalty",
                "presence_penalty",
            ]:
                if param in params:
                    payload[param] = params[param]

        return {k: v for k, v in payload.items() if v is not None}

    def generate_batch(
        self,
        inputs: List[Dict[str, Any]],
        return_logits: bool = False,
        raise_error: bool = False,
        max_retries: int = 3,
        **kwargs,
    ) -> List[Dict[str, Any]]:
        """Generate one prediction per input item, with per-item retries.

        Each item must carry an "input" key (str or chat message list).

        Args:
            inputs: Items to generate for; each is copied, never mutated.
            return_logits: Also request/return token logprob data.
            raise_error: Re-raise the final API error instead of recording it.
            max_retries: Attempts per item, with linear backoff between tries.

        Returns:
            One dict per item: the original item's keys plus either
            "prediction" (and optional "logprobs"/"tokens") or "error".
        """
        outputs = []

        for input_item in inputs:
            item = input_item.copy()
            result = None

            for attempt in range(max_retries):
                try:
                    payload = self._create_payload(
                        item["input"],
                        return_logits=return_logits,
                        **kwargs,
                    )

                    if not self.model_name.startswith("gpt"):
                        response = self._client.completions.create(**payload)
                        result = {"prediction": response.choices[0].text}
                        if return_logits:
                            result.update(
                                {
                                    "logprobs": response.choices[
                                        0
                                    ].logprobs.token_logprobs,
                                    "tokens": response.choices[0].logprobs.tokens,
                                }
                            )
                    else:
                        response = self._client.chat.completions.create(**payload)
                        result = {"prediction": response.choices[0].message.content}
                        if return_logits and hasattr(response.choices[0], "logprobs"):
                            result["logprobs"] = response.choices[0].logprobs

                    break

                except Exception as e:
                    if attempt == max_retries - 1:
                        if raise_error:
                            raise
                        item["error"] = str(e)
                    else:
                        # Linear backoff: 1s, 2s, ... before the next retry.
                        time.sleep(1 * (attempt + 1))

            # BUG FIX: the original called outputs.append(**item, ...), but
            # list.append takes exactly one positional argument, so every
            # call raised TypeError. Merge input item and result into one
            # dict instead, preserving the recorded error message (the old
            # fallback would have clobbered item["error"] = str(e)).
            if result is None and "error" not in item:
                item["error"] = "Failed to generate"
            outputs.append({**item, **(result or {})})

        return outputs
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ pre_commit==4.0.1
transformers>=4.0.0
torch>=2.0.0
pytest>=7.3.0
openai>=1.0.0