Skip to content

Commit

Permalink
Add API key pool
Browse files Browse the repository at this point in the history
  • Loading branch information
ionic-bond committed Dec 24, 2024
1 parent 8a298d6 commit 3a905d1
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 14 deletions.
8 changes: 5 additions & 3 deletions stream_translator_gpt/audio_transcriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from openai import OpenAI, DefaultHttpxClient

from . import filters
from .common import TranslationTask, SAMPLE_RATE, LoopWorkerBase, sec2str
from .common import TranslationTask, SAMPLE_RATE, LoopWorkerBase, sec2str, ApiKeyPool

TEMP_AUDIO_FILE_NAME = '_whisper_api_temp.wav'

Expand Down Expand Up @@ -72,7 +72,7 @@ class RemoteOpenaiWhisper(OpenaiWhisper):
# https://platform.openai.com/docs/api-reference/audio/createTranscription?lang=python

def __init__(self, language: str, proxy: str) -> None:
self.client = OpenAI(http_client=DefaultHttpxClient(proxy=proxy))
self.proxy = proxy
self.language = language

def __del__(self):
Expand All @@ -83,7 +83,9 @@ def transcribe(self, audio: np.array, **transcribe_options) -> str:
with open(TEMP_AUDIO_FILE_NAME, 'wb') as audio_file:
write_audio(audio_file, SAMPLE_RATE, audio)
with open(TEMP_AUDIO_FILE_NAME, 'rb') as audio_file:
result = self.client.audio.transcriptions.create(model='whisper-1', file=audio_file,
ApiKeyPool.use_openai_api()
client = OpenAI(http_client=DefaultHttpxClient(proxy=self.proxy))
result = client.audio.transcriptions.create(model='whisper-1', file=audio_file,
language=self.language).text
os.remove(TEMP_AUDIO_FILE_NAME)
return result
34 changes: 34 additions & 0 deletions stream_translator_gpt/common.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import os
from abc import ABC, abstractmethod
from datetime import datetime

import google.generativeai as genai
import numpy as np
from google.api_core.client_options import ClientOptions
from whisper.audio import SAMPLE_RATE


Expand Down Expand Up @@ -38,3 +41,34 @@ def sec2str(second: float):
result = dt.strftime('%H:%M:%S')
result += ',' + str(int(second * 10 % 10))
return result


class ApiKeyPool():

@classmethod
def init(cls, openai_api_key, gpt_base_url, google_api_key, gemini_base_url):
if gpt_base_url:
os.environ['OPENAI_BASE_URL'] = gpt_base_url
cls.gemini_base_url = gemini_base_url

cls.openai_api_key_list = [key.strip() for key in openai_api_key.split(',')] if openai_api_key else None
cls.openai_api_key_index = 0
cls.use_openai_api()
cls.google_api_key_list = [key.strip() for key in google_api_key.split(',')] if google_api_key else None
cls.google_api_key_index = 0
cls.use_google_api()

@classmethod
def use_openai_api(cls):
if not cls.openai_api_key_list:
return
os.environ['OPENAI_API_KEY'] = cls.openai_api_key_list[cls.openai_api_key_index]
cls.openai_api_key_index = (cls.openai_api_key_index + 1) % len(cls.openai_api_key_list)

@classmethod
def use_google_api(cls):
if not cls.google_api_key_list:
return
gemini_client_options = ClientOptions(api_endpoint=cls.gemini_base_url)
genai.configure(api_key=cls.google_api_key_list[cls.google_api_key_index], client_options=gemini_client_options, transport='rest')
cls.google_api_key_index = (cls.google_api_key_index + 1) % len(cls.google_api_key_list)
4 changes: 3 additions & 1 deletion stream_translator_gpt/llm_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from google.generativeai.types import HarmCategory, HarmBlockThreshold
from openai import OpenAI, DefaultHttpxClient, APITimeoutError, APIConnectionError

from .common import TranslationTask, LoopWorkerBase
from .common import TranslationTask, LoopWorkerBase, ApiKeyPool


# The double quotes in the values of JSON have not been escaped, so manual escaping is necessary.
Expand Down Expand Up @@ -87,6 +87,7 @@ def _append_history_message(self, user_content: str, assistant_content: str):

def _translate_by_gpt(self, translation_task: TranslationTask):
# https://platform.openai.com/docs/api-reference/chat/create?lang=python
ApiKeyPool.use_openai_api()
client = OpenAI(http_client=DefaultHttpxClient(proxy=self.proxy))
system_prompt = 'You are a translation engine.'
if self.use_json_result:
Expand Down Expand Up @@ -130,6 +131,7 @@ def _gpt_to_gemini(gpt_messages: list):

def _translate_by_gemini(self, translation_task: TranslationTask):
# https://ai.google.dev/tutorials/python_quickstart
ApiKeyPool.use_google_api()
client = genai.GenerativeModel(self.model)
messages = self._gpt_to_gemini(self.history_messages)
user_content = '{}: \n{}'.format(self.prompt, translation_task.transcribed_text)
Expand Down
12 changes: 2 additions & 10 deletions stream_translator_gpt/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@
import threading
import time

import google.generativeai as genai
from google.api_core.client_options import ClientOptions

from .common import ApiKeyPool
from .audio_getter import StreamAudioGetter, LocalFileAudioGetter, DeviceAudioGetter
from .audio_slicer import AudioSlicer
from .audio_transcriber import OpenaiWhisper, FasterWhisper, RemoteOpenaiWhisper
Expand All @@ -28,13 +26,7 @@ def main(url, format, cookies, input_proxy, device_index, device_recording_inter
gpt_base_url, gemini_base_url, processing_proxy, use_json_result, retry_if_translation_fails,
output_timestamps, hide_transcribe_result, output_proxy, output_file_path, cqhttp_url, cqhttp_token,
discord_webhook_url, telegram_token, telegram_chat_id, **transcribe_options):
if openai_api_key:
os.environ['OPENAI_API_KEY'] = openai_api_key
if gpt_base_url:
os.environ['OPENAI_BASE_URL'] = gpt_base_url
if google_api_key:
gemini_client_options = ClientOptions(api_endpoint=gemini_base_url)
genai.configure(api_key=google_api_key, client_options=gemini_client_options, transport='rest')
ApiKeyPool.init(openai_api_key=openai_api_key, gpt_base_url=gpt_base_url, google_api_key=google_api_key, gemini_base_url=gemini_base_url)

getter_to_slicer_queue = queue.SimpleQueue()
slicer_to_transcriber_queue = queue.SimpleQueue()
Expand Down

0 comments on commit 3a905d1

Please sign in to comment.