From 03355cdc3ada07c55646b39fe0670b9c0afa7e91 Mon Sep 17 00:00:00 2001
From: "deling.sc"
Date: Wed, 12 Jun 2024 11:54:09 +0800
Subject: [PATCH] feat: add openai sdk (for service api) support.

---
 llmuses/models/model_adapter.py | 78 +++++++++++++++++----------------
 llmuses/models/openai_model.py  | 45 ++++++++++++++++---
 llmuses/run.py                  | 48 ++++++++++++++------
 3 files changed, 115 insertions(+), 56 deletions(-)

diff --git a/llmuses/models/model_adapter.py b/llmuses/models/model_adapter.py
index 0dfae4dd..e359e33c 100644
--- a/llmuses/models/model_adapter.py
+++ b/llmuses/models/model_adapter.py
@@ -16,6 +16,7 @@
 from llmuses.models.template import get_template, StopWordsCriteria, TemplateType, fuzzy_match
 from llmuses.utils.logger import get_logger
 from transformers import StoppingCriteriaList
+from llmuses.models.openai_model import OpenAIModel
 
 logger = get_logger()
 
@@ -413,7 +414,7 @@ def __init__(self,
         self.generation_config = model_cfg.get('generation_config', None)
         self.generation_template = model_cfg.get('generation_template', None)
 
-        if self.generation_config is None or self.generation_template is None:
+        if type(model) is not OpenAIModel and (self.generation_config is None or self.generation_template is None):
             raise ValueError('generation_config or generation_template is required for chat generation.')
 
         super().__init__(model=model, tokenizer=tokenizer, model_cfg=model_cfg)
@@ -423,42 +424,45 @@ def _model_generate(self, query: str, infer_cfg: dict) -> str:
                           history=[],
                           system=None)
 
-        inputs, _ = self.generation_template.encode(example)
-        input_ids = inputs['input_ids']
-        input_ids = torch.tensor(input_ids)[None].to(self.device)
-        attention_mask = torch.ones_like(input_ids).to(self.device)
-
-        # Process infer_cfg
-        infer_cfg = infer_cfg or {}
-        if isinstance(infer_cfg.get('num_return_sequences'), int) and infer_cfg['num_return_sequences'] > 1:
-            infer_cfg['do_sample'] = True
-
-        # TODO: stop settings
-        stop = infer_cfg.get('stop', None)
-        eos_token_id = self.tokenizer.encode(stop, add_special_tokens=False)[0] \
-            if stop else self.tokenizer.eos_token_id
-
-        if eos_token_id is not None:
-            infer_cfg['eos_token_id'] = eos_token_id
-            infer_cfg['pad_token_id'] = eos_token_id  # setting eos_token_id as pad token
-
-
-        self.generation_config.update(**infer_cfg)
-
-        # stopping
-        stop_words = [self.generation_template.suffix[-1]]
-        decode_kwargs = {}
-        stopping_criteria = StoppingCriteriaList(
-            [StopWordsCriteria(self.tokenizer, stop_words, **decode_kwargs)])
-
-        # Run inference
-        output_ids = self.model.generate(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            generation_config=self.generation_config,
-            stopping_criteria=stopping_criteria, )
-
-        response = self.tokenizer.decode(output_ids[0, len(input_ids[0]):], True, **decode_kwargs)
+        if type(self.model) is OpenAIModel:
+            response = self.model.completion(query)
+        else:
+            inputs, _ = self.generation_template.encode(example)
+            input_ids = inputs['input_ids']
+            input_ids = torch.tensor(input_ids)[None].to(self.device)
+            attention_mask = torch.ones_like(input_ids).to(self.device)
+
+            # Process infer_cfg
+            infer_cfg = infer_cfg or {}
+            if isinstance(infer_cfg.get('num_return_sequences'), int) and infer_cfg['num_return_sequences'] > 1:
+                infer_cfg['do_sample'] = True
+
+            # TODO: stop settings
+            stop = infer_cfg.get('stop', None)
+            eos_token_id = self.tokenizer.encode(stop, add_special_tokens=False)[0] \
+                if stop else self.tokenizer.eos_token_id
+
+            if eos_token_id is not None:
+                infer_cfg['eos_token_id'] = eos_token_id
+                infer_cfg['pad_token_id'] = eos_token_id  # setting eos_token_id as pad token
+
+
+            self.generation_config.update(**infer_cfg)
+
+            # stopping
+            stop_words = [self.generation_template.suffix[-1]]
+            decode_kwargs = {}
+            stopping_criteria = StoppingCriteriaList(
+                [StopWordsCriteria(self.tokenizer, stop_words, **decode_kwargs)])
+
+            # Run inference
+            output_ids = self.model.generate(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                generation_config=self.generation_config,
+                stopping_criteria=stopping_criteria, )
+
+            response = self.tokenizer.decode(output_ids[0, len(input_ids[0]):], True, **decode_kwargs)
         return response
 
     @torch.no_grad()
diff --git a/llmuses/models/openai_model.py b/llmuses/models/openai_model.py
index caa11ac1..35f7d665 100644
--- a/llmuses/models/openai_model.py
+++ b/llmuses/models/openai_model.py
@@ -3,7 +3,7 @@
 import os
 import time
 
-import openai
+from openai import OpenAI
 
 from llmuses.models import ChatBaseModel
 from llmuses.utils.logger import get_logger
@@ -19,13 +19,16 @@ class OpenAIModel(ChatBaseModel):
 
     MAX_RETRIES = 3
 
-    def __init__(self, model_cfg: dict, **kwargs):
+    def __init__(self, model_id: str, model_cfg: dict, base_url: str, api_key: str, **kwargs):
         super(OpenAIModel, self).__init__(model_cfg=model_cfg, **kwargs)
 
-        openai_api_key = os.environ.get('OPENAI_API_KEY', None)
-        self.api_key = self.model_cfg.get('api_key', openai_api_key)
+        self.client = OpenAI(
+            api_key=api_key,
+            base_url=base_url
+        )
+        self.model_id = model_id
 
-        if not self.api_key:
+        if not api_key:
             logger.error('OpenAI API key is not provided, please set it in environment variable OPENAI_API_KEY')
             # raise ValueError(
             #     'OpenAI API key is not provided, '
@@ -52,6 +55,35 @@ def predict(self, model_id: str, inputs: dict, **kwargs) -> dict:
 
         return res
 
+    def completion(self, query: str) -> str:
+        predictions = ""
+        for i in range(self.MAX_RETRIES):
+            try:
+                resp = self.client.chat.completions.create(
+                    model=self.model_id,
+                    messages=[{
+                        "role": "user",
+                        "content": query
+                    }],
+                )
+                if resp:
+                    predictions = resp.choices[0].message.content
+                else:
+                    logger.warning(
+                        f'OpenAI GPT API call failed: got empty response '
+                        f'for input {query}')
+                    predictions = ""
+
+                return predictions
+
+            except Exception as e:
+                logger.warning(f'OpenAI API call failed: {e}')
+                time.sleep(3)
+
+        logger.error(
+            f'OpenAI API call failed after {self.MAX_RETRIES} retries')
+        return predictions
+
     def _predict(self,
                  model_id,
                  sys_prompt,
@@ -101,3 +133,6 @@ def _predict(self,
         logger.error(
             f'OpenAI API call failed after {self.MAX_RETRIES} retries')
         return res
+
+    def mock_predict(self):
+        return "TEST 测试一下"
\ No newline at end of file
diff --git a/llmuses/run.py b/llmuses/run.py
index 6d34d78f..27b32f4c 100644
--- a/llmuses/run.py
+++ b/llmuses/run.py
@@ -13,6 +13,7 @@
 from llmuses.constants import OutputsStructure
 from llmuses.tools.combine_reports import ReportsRecorder
 from llmuses.models import load_model
+from llmuses.models.openai_model import OpenAIModel
 import os
 
 logger = get_logger()
@@ -32,6 +33,16 @@ def parse_args():
                         help='The model id on modelscope, or local model dir.',
                         type=str,
                         required=True)
+    parser.add_argument('--base-url',
+                        help='The base url for openai, or other model api service.',
+                        type=str,
+                        required=False,
+                        default="")
+    parser.add_argument('--api-key',
+                        help='The api key for openai, or other model api service.',
+                        type=str,
+                        required=False,
+                        default="")
     parser.add_argument('--template-type',
                         type=str,
                         help='The template type for generation, should be a string.'
@@ -181,20 +192,29 @@ def main():
         datasets_list = args.datasets
 
     if not args.dry_run:
-        model, tokenizer, model_cfg = load_model(model_id=model_id,
-                                                 device_map=model_args.get("device_map", "auto"),
-                                                 torch_dtype=model_precision,
-                                                 model_revision=model_revision,
-                                                 cache_dir=args.work_dir,
-                                                 template_type=template_type,
-                                                 )
-        qwen_model, qwen_tokenizer, qwen_model_cfg = load_model(model_id=qwen_model_id,
-                                                                device_map=model_args.get("device_map", "auto"),
-                                                                torch_dtype=model_precision,
-                                                                model_revision=None,
-                                                                cache_dir=args.work_dir,
-                                                                template_type=template_type,
-                                                                ) if len(qwen_model_id) > 0 else (None, None, None)
+        base_url = args.base_url
+        api_key = args.api_key
+        if base_url and api_key:
+            tokenizer, model_cfg = None, {"model_id": model_id}
+            model = OpenAIModel(model_id,
+                                model_cfg=model_cfg,
+                                base_url=base_url,
+                                api_key=api_key)
+        else:
+            model, tokenizer, model_cfg = load_model(model_id=model_id,
+                                                     device_map=model_args.get("device_map", "auto"),
+                                                     torch_dtype=model_precision,
+                                                     model_revision=model_revision,
+                                                     cache_dir=args.work_dir,
+                                                     template_type=template_type,
+                                                     )
+        qwen_model, qwen_tokenizer, qwen_model_cfg = load_model(model_id=qwen_model_id,
+                                                                device_map=model_args.get("device_map", "auto"),
+                                                                torch_dtype=model_precision,
+                                                                model_revision=None,
+                                                                cache_dir=args.work_dir,
+                                                                template_type=template_type,
+                                                                ) if len(qwen_model_id) > 0 else (None, None, None)
     else:
         logger.warning('Dry run mode, will use dummy model.')
         model, tokenizer, model_cfg = None, None, None
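
Usage note (not part of the patch): with this change, run.py builds an OpenAIModel whenever both --base-url and --api-key are passed alongside --model, and falls back to load_model otherwise. Below is a minimal sketch of driving the new class directly; the endpoint URL, API key and served model name are placeholder assumptions, not values taken from this repository.

    from llmuses.models.openai_model import OpenAIModel

    # Placeholder values -- point these at a real OpenAI-compatible service.
    model_id = 'qwen-plus'
    model = OpenAIModel(model_id,
                        model_cfg={'model_id': model_id},  # mirrors the construction in run.py
                        base_url='http://localhost:8000/v1',
                        api_key='EMPTY')

    # completion() retries up to MAX_RETRIES times and returns the message content.
    print(model.completion('Hello, who are you?'))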