From 3a905d1d98b9eccc465bbd59abaa0eaa281c134a Mon Sep 17 00:00:00 2001 From: ionic-bond Date: Wed, 25 Dec 2024 00:39:44 +0800 Subject: [PATCH] Add API key pool --- stream_translator_gpt/audio_transcriber.py | 8 +++-- stream_translator_gpt/common.py | 34 ++++++++++++++++++++++ stream_translator_gpt/llm_translator.py | 4 ++- stream_translator_gpt/translator.py | 12 ++------ 4 files changed, 44 insertions(+), 14 deletions(-) diff --git a/stream_translator_gpt/audio_transcriber.py b/stream_translator_gpt/audio_transcriber.py index 2c1b2af..b76b5c2 100644 --- a/stream_translator_gpt/audio_transcriber.py +++ b/stream_translator_gpt/audio_transcriber.py @@ -6,7 +6,7 @@ from openai import OpenAI, DefaultHttpxClient from . import filters -from .common import TranslationTask, SAMPLE_RATE, LoopWorkerBase, sec2str +from .common import TranslationTask, SAMPLE_RATE, LoopWorkerBase, sec2str, ApiKeyPool TEMP_AUDIO_FILE_NAME = '_whisper_api_temp.wav' @@ -72,7 +72,7 @@ class RemoteOpenaiWhisper(OpenaiWhisper): # https://platform.openai.com/docs/api-reference/audio/createTranscription?lang=python def __init__(self, language: str, proxy: str) -> None: - self.client = OpenAI(http_client=DefaultHttpxClient(proxy=proxy)) + self.proxy = proxy self.language = language def __del__(self): @@ -83,7 +83,9 @@ def transcribe(self, audio: np.array, **transcribe_options) -> str: with open(TEMP_AUDIO_FILE_NAME, 'wb') as audio_file: write_audio(audio_file, SAMPLE_RATE, audio) with open(TEMP_AUDIO_FILE_NAME, 'rb') as audio_file: - result = self.client.audio.transcriptions.create(model='whisper-1', file=audio_file, + ApiKeyPool.use_openai_api() + client = OpenAI(http_client=DefaultHttpxClient(proxy=self.proxy)) + result = client.audio.transcriptions.create(model='whisper-1', file=audio_file, language=self.language).text os.remove(TEMP_AUDIO_FILE_NAME) return result diff --git a/stream_translator_gpt/common.py b/stream_translator_gpt/common.py index 03d2016..dbe4a83 100644 --- a/stream_translator_gpt/common.py +++ b/stream_translator_gpt/common.py @@ -1,7 +1,10 @@ +import os from abc import ABC, abstractmethod from datetime import datetime +import google.generativeai as genai import numpy as np +from google.api_core.client_options import ClientOptions from whisper.audio import SAMPLE_RATE @@ -38,3 +41,34 @@ def sec2str(second: float): result = dt.strftime('%H:%M:%S') result += ',' + str(int(second * 10 % 10)) return result + + +class ApiKeyPool(): + + @classmethod + def init(cls, openai_api_key, gpt_base_url, google_api_key, gemini_base_url): + if gpt_base_url: + os.environ['OPENAI_BASE_URL'] = gpt_base_url + cls.gemini_base_url = gemini_base_url + + cls.openai_api_key_list = [key.strip() for key in openai_api_key.split(',')] if openai_api_key else None + cls.openai_api_key_index = 0 + cls.use_openai_api() + cls.google_api_key_list = [key.strip() for key in google_api_key.split(',')] if google_api_key else None + cls.google_api_key_index = 0 + cls.use_google_api() + + @classmethod + def use_openai_api(cls): + if not cls.openai_api_key_list: + return + os.environ['OPENAI_API_KEY'] = cls.openai_api_key_list[cls.openai_api_key_index] + cls.openai_api_key_index = (cls.openai_api_key_index + 1) % len(cls.openai_api_key_list) + + @classmethod + def use_google_api(cls): + if not cls.google_api_key_list: + return + gemini_client_options = ClientOptions(api_endpoint=cls.gemini_base_url) + genai.configure(api_key=cls.google_api_key_list[cls.google_api_key_index], client_options=gemini_client_options, transport='rest') + cls.google_api_key_index = (cls.google_api_key_index + 1) % len(cls.google_api_key_list) diff --git a/stream_translator_gpt/llm_translator.py b/stream_translator_gpt/llm_translator.py index 2706f2c..7ebb1de 100644 --- a/stream_translator_gpt/llm_translator.py +++ b/stream_translator_gpt/llm_translator.py @@ -11,7 +11,7 @@ from google.generativeai.types import HarmCategory, HarmBlockThreshold from openai import OpenAI, DefaultHttpxClient, APITimeoutError, APIConnectionError -from .common import TranslationTask, LoopWorkerBase +from .common import TranslationTask, LoopWorkerBase, ApiKeyPool # The double quotes in the values of JSON have not been escaped, so manual escaping is necessary. @@ -87,6 +87,7 @@ def _append_history_message(self, user_content: str, assistant_content: str): def _translate_by_gpt(self, translation_task: TranslationTask): # https://platform.openai.com/docs/api-reference/chat/create?lang=python + ApiKeyPool.use_openai_api() client = OpenAI(http_client=DefaultHttpxClient(proxy=self.proxy)) system_prompt = 'You are a translation engine.' if self.use_json_result: @@ -130,6 +131,7 @@ def _gpt_to_gemini(gpt_messages: list): def _translate_by_gemini(self, translation_task: TranslationTask): # https://ai.google.dev/tutorials/python_quickstart + ApiKeyPool.use_google_api() client = genai.GenerativeModel(self.model) messages = self._gpt_to_gemini(self.history_messages) user_content = '{}: \n{}'.format(self.prompt, translation_task.transcribed_text) diff --git a/stream_translator_gpt/translator.py b/stream_translator_gpt/translator.py index c4ef7fa..878c0c3 100644 --- a/stream_translator_gpt/translator.py +++ b/stream_translator_gpt/translator.py @@ -5,9 +5,7 @@ import threading import time -import google.generativeai as genai -from google.api_core.client_options import ClientOptions - +from .common import ApiKeyPool from .audio_getter import StreamAudioGetter, LocalFileAudioGetter, DeviceAudioGetter from .audio_slicer import AudioSlicer from .audio_transcriber import OpenaiWhisper, FasterWhisper, RemoteOpenaiWhisper @@ -28,13 +26,7 @@ def main(url, format, cookies, input_proxy, device_index, device_recording_inter gpt_base_url, gemini_base_url, processing_proxy, use_json_result, retry_if_translation_fails, output_timestamps, hide_transcribe_result, output_proxy, output_file_path, cqhttp_url, cqhttp_token, discord_webhook_url, telegram_token, telegram_chat_id, **transcribe_options): - if openai_api_key: - os.environ['OPENAI_API_KEY'] = openai_api_key - if gpt_base_url: - os.environ['OPENAI_BASE_URL'] = gpt_base_url - if google_api_key: - gemini_client_options = ClientOptions(api_endpoint=gemini_base_url) - genai.configure(api_key=google_api_key, client_options=gemini_client_options, transport='rest') + ApiKeyPool.init(openai_api_key=openai_api_key, gpt_base_url=gpt_base_url, google_api_key=google_api_key, gemini_base_url=gemini_base_url) getter_to_slicer_queue = queue.SimpleQueue() slicer_to_transcriber_queue = queue.SimpleQueue()