Skip to content

Commit

Permalink
refactor: class-based design (#15)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: Tohrusky <65994850+Tohrusky@users.noreply.github.com>
  • Loading branch information
NULL204 and Tohrusky authored Dec 1, 2024
1 parent b67c9c6 commit f991615
Show file tree
Hide file tree
Showing 12 changed files with 196 additions and 79 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/CI-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
CI:
strategy:
matrix:
os-version: ["ubuntu-20.04", "macos-13", "windows-latest"]
os-version: ["ubuntu-20.04", "windows-latest", "macos-13"]
python-version: ["3.9"]
poetry-version: ["1.8.3"]

Expand All @@ -48,7 +48,7 @@ jobs:
- name: Test
run: |
pip install numpy==1.26.4
pip install pre-commit pytest mypy ruff types-requests pytest-cov pytest-asyncio coverage pydantic openai openai-whisper requests beautifulsoup4 tenacity pysubs2
pip install pre-commit pytest mypy ruff types-requests pytest-cov pytest-asyncio coverage pydantic openai openai-whisper httpx tenacity pysubs2
make lint
make test
Expand Down
41 changes: 15 additions & 26 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,36 +40,25 @@ yuisub -h # Displays help message
```python3
import asyncio

from yuisub import translate, bilingual, load
from yuisub.a2t import WhisperModel
from yuisub import SubtitleTranslator

# use an asynchronous environment
# Using an asynchronous environment
async def main() -> None:

# sub from audio
model = WhisperModel(name="medium", device="cuda")
sub = model.transcribe(audio="path/to/audio.mp3")

# sub from file
# sub = load("path/to/input.srt")

# generate bilingual subtitle
sub_zh = await translate(
sub=sub,
model="gpt_model_name",
api_key="your_openai_api_key",
base_url="api_url",
bangumi_url="https://bangumi.tv/subject/424883/"
)

sub_bilingual = await bilingual(
sub_origin=sub,
sub_zh=sub_zh
translator = SubtitleTranslator(
# if you wanna use audio input
# torch_device='cuda',
# whisper_model='medium',

model='gpt_model_name',
api_key='your_openai_api_key',
base_url='api_url',
bangumi_url='https://bangumi.tv/subject/424883/',
bangumi_access_token='your_bangumi_token',
)

# save the ASS files
sub_zh.save("path/to/output.zh.ass")
sub_bilingual.save("path/to/output.bilingual.ass")
sub_zh, sub_bilingual = await translator.get_subtitles(sub='path/to/sub.srt') # Or audio='path/to/audio.mp3',
sub_zh.save('path/to/output_zh.ass')
sub_bilingual.save('path/to/output_bilingual.ass')

asyncio.run(main())
```
Expand Down
4 changes: 3 additions & 1 deletion tests/test_bangumi.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from yuisub import bangumi

from . import util


async def test_bangumi() -> None:
url_list = [
Expand All @@ -9,6 +11,6 @@ async def test_bangumi() -> None:
]

for url in url_list:
r = await bangumi(url)
r = await bangumi(url=url, token=util.BANGUMI_ACCESS_TOKEN)
print(r.introduction)
print(r.characters)
9 changes: 5 additions & 4 deletions tests/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

import pytest

from tests import util
from yuisub import ORIGIN, Summarizer, Translator, bangumi

from . import util

origin = ORIGIN(
origin="何だよ…けっこう多いじゃねぇか",
)
Expand Down Expand Up @@ -65,7 +66,7 @@ async def test_llm_bangumi() -> None:
model=util.OPENAI_MODEL,
api_key=util.OPENAI_API_KEY,
base_url=util.OPENAI_BASE_URL,
bangumi_info=await bangumi(util.BANGUMI_URL),
bangumi_info=await bangumi(url=util.BANGUMI_URL, token=util.BANGUMI_ACCESS_TOKEN),
)
print(t.system_prompt)
res = await t.ask(origin)
Expand All @@ -78,7 +79,7 @@ async def test_llm_bangumi_2() -> None:
model=util.OPENAI_MODEL,
api_key=util.OPENAI_API_KEY,
base_url=util.OPENAI_BASE_URL,
bangumi_info=await bangumi(util.BANGUMI_URL),
bangumi_info=await bangumi(url=util.BANGUMI_URL, token=util.BANGUMI_ACCESS_TOKEN),
)
print(t.system_prompt)
s = ORIGIN(
Expand All @@ -95,7 +96,7 @@ async def test_llm_summary() -> None:
model=util.OPENAI_MODEL,
api_key=util.OPENAI_API_KEY,
base_url=util.OPENAI_BASE_URL,
bangumi_info=await bangumi(util.BANGUMI_URL),
bangumi_info=await bangumi(url=util.BANGUMI_URL, token=util.BANGUMI_ACCESS_TOKEN),
)
print(t.system_prompt)
res = await t.ask(summary_origin)
Expand Down
4 changes: 3 additions & 1 deletion tests/test_sub.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

import pytest

from tests import util
from yuisub.a2t import WhisperModel
from yuisub.sub import bilingual, load, translate

from . import util


def test_sub() -> None:
sub = load(util.TEST_ENG_SRT)
Expand Down Expand Up @@ -34,6 +35,7 @@ async def test_bilingual_2() -> None:
api_key=util.OPENAI_API_KEY,
base_url=util.OPENAI_BASE_URL,
bangumi_url=util.BANGUMI_URL,
bangumi_access_token=util.BANGUMI_ACCESS_TOKEN,
)
sub_bilingual = await bilingual(sub_origin=sub, sub_zh=sub_zh)

Expand Down
39 changes: 39 additions & 0 deletions tests/test_translator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import os

import pytest

from yuisub.translator import SubtitleTranslator

from . import util


@pytest.mark.skipif(os.environ.get("GITHUB_ACTIONS") == "true", reason="Skipping test when running on CI")
async def test_translator_sub() -> None:
translator = SubtitleTranslator(
model=util.OPENAI_MODEL,
api_key=util.OPENAI_API_KEY,
base_url=util.OPENAI_BASE_URL,
bangumi_url=util.BANGUMI_URL,
bangumi_access_token=util.BANGUMI_ACCESS_TOKEN,
)

sub_zh, sub_bilingual = await translator.get_subtitles(sub=str(util.TEST_ENG_SRT))
sub_zh.save(util.projectPATH / "assets" / "test.zh.translator.sub.ass")
sub_bilingual.save(util.projectPATH / "assets" / "test.bilingual.translator.sub.ass")


@pytest.mark.skipif(os.environ.get("GITHUB_ACTIONS") == "true", reason="Skipping test when running on CI")
async def test_translator_audio() -> None:
translator = SubtitleTranslator(
torch_device=util.DEVICE,
whisper_model=util.MODEL_NAME,
model=util.OPENAI_MODEL,
api_key=util.OPENAI_API_KEY,
base_url=util.OPENAI_BASE_URL,
bangumi_url=util.BANGUMI_URL,
bangumi_access_token=util.BANGUMI_ACCESS_TOKEN,
)

sub_zh, sub_bilingual = await translator.get_subtitles(audio=str(util.TEST_AUDIO))
sub_zh.save(util.projectPATH / "assets" / "test.zh.translator.audio.ass")
sub_bilingual.save(util.projectPATH / "assets" / "test.bilingual.translator.audio.ass")
5 changes: 2 additions & 3 deletions tests/util.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
import os
from pathlib import Path

import torch

projectPATH = Path(__file__).resolve().parent.parent.absolute()

TEST_AUDIO = projectPATH / "assets" / "test.mp3"
TEST_ENG_SRT = projectPATH / "assets" / "eng.srt"

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE = "cpu" if os.environ.get("GITHUB_ACTIONS") == "true" else None
MODEL_NAME = "medium" if DEVICE == "cuda" else "tiny"

BANGUMI_URL = "https://bangumi.tv/subject/424883"
BANGUMI_ACCESS_TOKEN = ""

OPENAI_MODEL = str(os.getenv("OPENAI_MODEL")) if os.getenv("OPENAI_MODEL") else "deepseek-chat"
OPENAI_BASE_URL = str(os.getenv("OPENAI_BASE_URL")) if os.getenv("OPENAI_BASE_URL") else "https://api.deepseek.com"
Expand Down
1 change: 1 addition & 0 deletions yuisub/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
from yuisub.llm import Summarizer, Translator # noqa: F401
from yuisub.prompt import ORIGIN, ZH # noqa: F401
from yuisub.sub import advertisement, bilingual, load, translate # noqa: F401
from yuisub.translator import SubtitleTranslator # noqa: F401
57 changes: 17 additions & 40 deletions yuisub/__main__.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,24 @@
import argparse
import asyncio
import sys

from yuisub.sub import bilingual, load, translate
from yuisub import SubtitleTranslator

# ffmpeg -i test.mkv -c:a mp3 -map 0:a:0 test.mp3
# ffmpeg -i test.mkv -map 0:s:0 eng.srt
parser = argparse.ArgumentParser(description="Generate Bilingual Subtitle from audio or subtitle file")

parser = argparse.ArgumentParser()
parser.description = "Generate Bilingual Subtitle from audio or subtitle file"
# input
# Input
parser.add_argument("-a", "--AUDIO", type=str, help="Path to the audio file", required=False)
parser.add_argument("-s", "--SUB", type=str, help="Path to the input Subtitle file", required=False)
# subtitle output
# Output
parser.add_argument("-oz", "--OUTPUT_ZH", type=str, help="Path to save the Chinese ASS file", required=False)
parser.add_argument("-ob", "--OUTPUT_BILINGUAL", type=str, help="Path to save the bilingual ASS file", required=False)
# openai gpt
# OpenAI GPT
parser.add_argument("-om", "--OPENAI_MODEL", type=str, help="Openai model name", required=True)
parser.add_argument("-api", "--OPENAI_API_KEY", type=str, help="Openai API key", required=True)
parser.add_argument("-url", "--OPENAI_BASE_URL", type=str, help="Openai base URL", required=True)
# bangumi
# Bangumi
parser.add_argument("-bgm", "--BANGUMI_URL", type=str, help="Anime Bangumi URL", required=False)
parser.add_argument("-ac", "--BANGUMI_ACCESS_TOKEN", type=str, help="Anime Bangumi Access Token", required=False)
# whisper
# Whisper
parser.add_argument("-d", "--TORCH_DEVICE", type=str, help="Pytorch device to use", required=False)
parser.add_argument("-wm", "--WHISPER_MODEL", type=str, help="Whisper model to use", required=False)

Expand All @@ -33,47 +29,28 @@ async def main() -> None:
if args.AUDIO and args.SUB:
raise ValueError("Please provide only one input file, either audio or subtitle file")

if not args.AUDIO and not args.SUB:
raise ValueError("Please provide an input file, either audio or subtitle file")

if not args.OUTPUT_ZH and not args.OUTPUT_BILINGUAL:
raise ValueError("Please provide output paths for the subtitles.")

if args.AUDIO:
import torch

from yuisub.a2t import WhisperModel

if args.TORCH_DEVICE:
_DEVICE = args.TORCH_DEVICE
else:
_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
if sys.platform == "darwin":
_DEVICE = "mps"

if args.WHISPER_MODEL:
_MODEL = args.WHISPER_MODEL
else:
_MODEL = "medium" if _DEVICE == "cpu" else "large-v2"

model = WhisperModel(name=_MODEL, device=_DEVICE)

sub = model.transcribe(audio=args.AUDIO)

else:
sub = load(args.SUB)

sub_zh = await translate(
sub=sub,
translator = SubtitleTranslator(
model=args.OPENAI_MODEL,
api_key=args.OPENAI_API_KEY,
base_url=args.OPENAI_BASE_URL,
bangumi_url=args.BANGUMI_URL,
bangumi_access_token=args.BANGUMI_ACCESS_TOKEN,
torch_device=args.TORCH_DEVICE,
whisper_model=args.WHISPER_MODEL,
)

sub_bilingual = await bilingual(sub_origin=sub, sub_zh=sub_zh)

sub_zh, sub_bilingual = await translator.get_subtitles(
sub=args.SUB,
audio=args.AUDIO,
)
if args.OUTPUT_ZH:
sub_zh.save(args.OUTPUT_ZH)

if args.OUTPUT_BILINGUAL:
sub_bilingual.save(args.OUTPUT_BILINGUAL)

Expand Down
6 changes: 5 additions & 1 deletion yuisub/a2t.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@

class WhisperModel:
def __init__(
self, name: str = "medium", device: str = "cuda", download_root: Optional[str] = None, in_memory: bool = False
self,
name: str = "medium",
device: Optional[Union[str, torch.device]] = None,
download_root: Optional[str] = None,
in_memory: bool = False,
):
self.model = whisper.load_model(name=name, device=device, download_root=download_root, in_memory=in_memory)

Expand Down
2 changes: 1 addition & 1 deletion yuisub/sub.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,8 @@ async def translate(
base_url=base_url,
bangumi_info=bangumi_info,
)
print(summarizer.system_prompt)

print("Summarizing...")
# get summary
summary = await summarizer.ask(ORIGIN(origin="\n".join(trans_list)))

Expand Down
Loading

0 comments on commit f991615

Please sign in to comment.