feat: add openai sdk (for service api) support.
deling.sc committed Jun 12, 2024
1 parent 3d30e0b commit 03355cd
Showing 3 changed files with 115 additions and 56 deletions.
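In short, the commit wires an OpenAI-compatible service API into the evaluation pipeline in three places: dispatch in the chat adapter (llmuses/models/model_adapter.py), a new client wrapper (llmuses/models/openai_model.py), and CLI wiring (llmuses/run.py).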
78 changes: 41 additions & 37 deletions llmuses/models/model_adapter.py
@@ -16,6 +16,7 @@
 from llmuses.models.template import get_template, StopWordsCriteria, TemplateType, fuzzy_match
 from llmuses.utils.logger import get_logger
 from transformers import StoppingCriteriaList
+from llmuses.models.openai_model import OpenAIModel

 logger = get_logger()

@@ -413,7 +414,7 @@ def __init__(self,
         self.generation_config = model_cfg.get('generation_config', None)
         self.generation_template = model_cfg.get('generation_template', None)

-        if self.generation_config is None or self.generation_template is None:
+        if type(model) is not OpenAIModel and (self.generation_config is None or self.generation_template is None):
             raise ValueError('generation_config or generation_template is required for chat generation.')

         super().__init__(model=model, tokenizer=tokenizer, model_cfg=model_cfg)
@@ -423,42 +424,45 @@ def _model_generate(self, query: str, infer_cfg: dict) -> str:
                        history=[],
                        system=None)

-        inputs, _ = self.generation_template.encode(example)
-        input_ids = inputs['input_ids']
-        input_ids = torch.tensor(input_ids)[None].to(self.device)
-        attention_mask = torch.ones_like(input_ids).to(self.device)
-
-        # Process infer_cfg
-        infer_cfg = infer_cfg or {}
-        if isinstance(infer_cfg.get('num_return_sequences'), int) and infer_cfg['num_return_sequences'] > 1:
-            infer_cfg['do_sample'] = True
-
-        # TODO: stop settings
-        stop = infer_cfg.get('stop', None)
-        eos_token_id = self.tokenizer.encode(stop, add_special_tokens=False)[0] \
-            if stop else self.tokenizer.eos_token_id
-
-        if eos_token_id is not None:
-            infer_cfg['eos_token_id'] = eos_token_id
-            infer_cfg['pad_token_id'] = eos_token_id  # setting eos_token_id as pad token
-
-
-        self.generation_config.update(**infer_cfg)
-
-        # stopping
-        stop_words = [self.generation_template.suffix[-1]]
-        decode_kwargs = {}
-        stopping_criteria = StoppingCriteriaList(
-            [StopWordsCriteria(self.tokenizer, stop_words, **decode_kwargs)])
-
-        # Run inference
-        output_ids = self.model.generate(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            generation_config=self.generation_config,
-            stopping_criteria=stopping_criteria, )
-
-        response = self.tokenizer.decode(output_ids[0, len(input_ids[0]):], True, **decode_kwargs)
+        if type(self.model) is OpenAIModel:
+            response = self.model.completion(query)
+        else:
+            inputs, _ = self.generation_template.encode(example)
+            input_ids = inputs['input_ids']
+            input_ids = torch.tensor(input_ids)[None].to(self.device)
+            attention_mask = torch.ones_like(input_ids).to(self.device)
+
+            # Process infer_cfg
+            infer_cfg = infer_cfg or {}
+            if isinstance(infer_cfg.get('num_return_sequences'), int) and infer_cfg['num_return_sequences'] > 1:
+                infer_cfg['do_sample'] = True
+
+            # TODO: stop settings
+            stop = infer_cfg.get('stop', None)
+            eos_token_id = self.tokenizer.encode(stop, add_special_tokens=False)[0] \
+                if stop else self.tokenizer.eos_token_id
+
+            if eos_token_id is not None:
+                infer_cfg['eos_token_id'] = eos_token_id
+                infer_cfg['pad_token_id'] = eos_token_id  # setting eos_token_id as pad token
+
+
+            self.generation_config.update(**infer_cfg)
+
+            # stopping
+            stop_words = [self.generation_template.suffix[-1]]
+            decode_kwargs = {}
+            stopping_criteria = StoppingCriteriaList(
+                [StopWordsCriteria(self.tokenizer, stop_words, **decode_kwargs)])
+
+            # Run inference
+            output_ids = self.model.generate(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                generation_config=self.generation_config,
+                stopping_criteria=stopping_criteria, )
+
+            response = self.tokenizer.decode(output_ids[0, len(input_ids[0]):], True, **decode_kwargs)
         return response

     @torch.no_grad()
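A note on the dispatch above: "type(self.model) is OpenAIModel" is an exact-type check, so any subclass of OpenAIModel would fall through to the local-generation branch; isinstance is the usual idiom when subclasses should also take the service-API route. A minimal, runnable illustration of the difference (class names here are illustrative, not from the repository):

    # type() does exact-type matching; isinstance() also accepts subclasses.
    class Base:
        pass

    class Sub(Base):
        pass

    obj = Sub()
    print(type(obj) is Base)      # False: exact-type check rejects the subclass
    print(isinstance(obj, Base))  # True: isinstance accepts the subclass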
45 changes: 40 additions & 5 deletions llmuses/models/openai_model.py
@@ -3,7 +3,7 @@
 import os
 import time

-import openai
+from openai import OpenAI

 from llmuses.models import ChatBaseModel
 from llmuses.utils.logger import get_logger
@@ -19,13 +19,16 @@ class OpenAIModel(ChatBaseModel):

     MAX_RETRIES = 3

-    def __init__(self, model_cfg: dict, **kwargs):
+    def __init__(self, model_id: str, model_cfg: dict, base_url: str, api_key: str, **kwargs):
         super(OpenAIModel, self).__init__(model_cfg=model_cfg, **kwargs)

-        openai_api_key = os.environ.get('OPENAI_API_KEY', None)
-        self.api_key = self.model_cfg.get('api_key', openai_api_key)
+        self.client = OpenAI(
+            api_key=api_key,
+            base_url=base_url
+        )
+        self.model_id = model_id

-        if not self.api_key:
+        if not api_key:
             logger.error('OpenAI API key is not provided, please set it in environment variable OPENAI_API_KEY')
             # raise ValueError(
             #     'OpenAI API key is not provided, '
@@ -52,6 +55,35 @@ def predict(self, model_id: str, inputs: dict, **kwargs) -> dict:

         return res

+    def completion(self, query: str) -> str:
+        predictions = ""
+        for i in range(self.MAX_RETRIES):
+            try:
+                resp = self.client.chat.completions.create(
+                    model=self.model_id,
+                    messages=[{
+                        "role": "user",
+                        "content": query
+                    }],
+                )
+                if resp:
+                    predictions = resp.choices[0].message.content
+                else:
+                    logger.warning(
+                        f'OpenAI GPT API call failed: got empty response '
+                        f'for input {query}')
+                    predictions = ""
+
+                return predictions
+
+            except Exception as e:
+                logger.warning(f'OpenAI API call failed: {e}')
+                time.sleep(3)
+
+        logger.error(
+            f'OpenAI API call failed after {self.MAX_RETRIES} retries')
+        return predictions
+
     def _predict(self,
                  model_id,
                  sys_prompt,
@@ -101,3 +133,6 @@ def _predict(self,
         logger.error(
             f'OpenAI API call failed after {self.MAX_RETRIES} retries')
         return res
+
+    def mock_predict(self):
+        return "TEST 测试一下"
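For reference, a minimal usage sketch of the new constructor and completion(); the model name, base URL, and API key below are placeholders, and model_cfg is assumed to need only model_id, mirroring how run.py builds it:

    from llmuses.models.openai_model import OpenAIModel

    # Placeholder values: substitute a real endpoint, key, and model name.
    model = OpenAIModel(model_id='gpt-3.5-turbo',
                        model_cfg={'model_id': 'gpt-3.5-turbo'},
                        base_url='https://api.openai.com/v1',
                        api_key='sk-...')

    # completion() retries up to MAX_RETRIES (3) times, sleeping 3 seconds
    # after each failure, and returns '' if every attempt fails.
    print(model.completion('Say hello.'))

Note that the empty-response branch inside completion() mainly guards against the client returning None; a successful chat.completions.create() call returns an object that is always truthy.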
48 changes: 34 additions & 14 deletions llmuses/run.py
@@ -13,6 +13,7 @@
 from llmuses.constants import OutputsStructure
 from llmuses.tools.combine_reports import ReportsRecorder
 from llmuses.models import load_model
+from llmuses.models.openai_model import OpenAIModel
 import os

 logger = get_logger()
@@ -32,6 +33,16 @@ def parse_args():
                         help='The model id on modelscope, or local model dir.',
                         type=str,
                         required=True)
+    parser.add_argument('--base-url',
+                        help='The base url for openai, or other model api service.',
+                        type=str,
+                        required=False,
+                        default="")
+    parser.add_argument('--api-key',
+                        help='The api key for openai, or other model api service.',
+                        type=str,
+                        required=False,
+                        default="")
     parser.add_argument('--template-type',
                         type=str,
                         help='The template type for generation, should be a string.'
@@ -181,20 +192,29 @@ def main():
     datasets_list = args.datasets

     if not args.dry_run:
-        model, tokenizer, model_cfg = load_model(model_id=model_id,
-                                                 device_map=model_args.get("device_map", "auto"),
-                                                 torch_dtype=model_precision,
-                                                 model_revision=model_revision,
-                                                 cache_dir=args.work_dir,
-                                                 template_type=template_type,
-                                                 )
-        qwen_model, qwen_tokenizer, qwen_model_cfg = load_model(model_id=qwen_model_id,
-                                                                device_map=model_args.get("device_map", "auto"),
-                                                                torch_dtype=model_precision,
-                                                                model_revision=None,
-                                                                cache_dir=args.work_dir,
-                                                                template_type=template_type,
-                                                                ) if len(qwen_model_id) > 0 else (None, None, None)
+        base_url = args.base_url
+        api_key = args.api_key
+        if base_url and api_key:
+            tokenizer, model_cfg = None, {"model_id": model_id}
+            model = OpenAIModel(model_id,
+                                model_cfg=model_cfg,
+                                base_url=base_url,
+                                api_key=api_key)
+        else:
+            model, tokenizer, model_cfg = load_model(model_id=model_id,
+                                                     device_map=model_args.get("device_map", "auto"),
+                                                     torch_dtype=model_precision,
+                                                     model_revision=model_revision,
+                                                     cache_dir=args.work_dir,
+                                                     template_type=template_type,
+                                                     )
+            qwen_model, qwen_tokenizer, qwen_model_cfg = load_model(model_id=qwen_model_id,
+                                                                    device_map=model_args.get("device_map", "auto"),
+                                                                    torch_dtype=model_precision,
+                                                                    model_revision=None,
+                                                                    cache_dir=args.work_dir,
+                                                                    template_type=template_type,
+                                                                    ) if len(qwen_model_id) > 0 else (None, None, None)
     else:
         logger.warning('Dry run mode, will use dummy model.')
         model, tokenizer, model_cfg = None, None, None
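With these flags in place, supplying both --base-url and --api-key together with the usual model-id argument routes evaluation through OpenAIModel instead of loading local weights; leaving either at its empty default keeps the existing load_model path. One caveat, stated as a reading of the diff rather than a tested claim: the service-API branch assigns only model, tokenizer, and model_cfg, so qwen_model, qwen_tokenizer, and qwen_model_cfg remain unbound on that path, and any later references to them would presumably need a guard.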
