diff --git a/test/unit/tts_pipeline/test_tts_engines.py b/test/unit/tts_pipeline/test_tts_engines.py index 1fa32c904..c49bf5462 100644 --- a/test/unit/tts_pipeline/test_tts_engines.py +++ b/test/unit/tts_pipeline/test_tts_engines.py @@ -4,7 +4,7 @@ from fastapi import HTTPException from voicevox_engine.dev.tts_engine.mock import MockTTSEngine -from voicevox_engine.tts_pipeline.tts_engine import TTSEngineManager +from voicevox_engine.tts_pipeline.tts_engine import LATEST_VERSION, TTSEngineManager def test_tts_engines_register_engine() -> None: @@ -48,6 +48,25 @@ def test_tts_engines_get_engine_existing() -> None: assert true_acquired_tts_engine == acquired_tts_engine +def test_tts_engines_get_engine_latest() -> None: + """TTSEngineManager.get_engine(LATEST_VERSION) で最新版の TTS エンジンを取得できる。""" + # Inputs + tts_engines = TTSEngineManager() + tts_engine1 = MockTTSEngine() + tts_engine2 = MockTTSEngine() + tts_engine3 = MockTTSEngine() + tts_engines.register_engine(tts_engine1, "0.0.1") + tts_engines.register_engine(tts_engine2, "0.0.2") + tts_engines.register_engine(tts_engine3, "0.1.0") + # Expects + true_acquired_tts_engine = tts_engine3 + # Outputs + acquired_tts_engine = tts_engines.get_engine(LATEST_VERSION) + + # Test + assert true_acquired_tts_engine == acquired_tts_engine + + def test_tts_engines_get_engine_missing() -> None: """TTSEngineManager.get_engine() で存在しない TTS エンジンを取得しようとするとエラーになる。""" # Inputs diff --git a/voicevox_engine/app/application.py b/voicevox_engine/app/application.py index 90611572f..8315e2dc6 100644 --- a/voicevox_engine/app/application.py +++ b/voicevox_engine/app/application.py @@ -80,11 +80,9 @@ def _get_core_characters(version: str | None) -> list[CoreCharacter]: ) app.include_router( - generate_tts_pipeline_router( - tts_engines, core_manager, preset_manager, cancellable_engine - ) + generate_tts_pipeline_router(tts_engines, preset_manager, cancellable_engine) ) - app.include_router(generate_morphing_router(tts_engines, core_manager, metas_store)) + app.include_router(generate_morphing_router(tts_engines, metas_store)) app.include_router( generate_preset_router(preset_manager, verify_mutability_allowed) ) diff --git a/voicevox_engine/app/routers/morphing.py b/voicevox_engine/app/routers/morphing.py index 458a82333..dd8199757 100644 --- a/voicevox_engine/app/routers/morphing.py +++ b/voicevox_engine/app/routers/morphing.py @@ -10,7 +10,6 @@ from starlette.background import BackgroundTask from starlette.responses import FileResponse -from voicevox_engine.core.core_initializer import CoreManager from voicevox_engine.metas.Metas import StyleId from voicevox_engine.metas.MetasStore import MetasStore from voicevox_engine.model import AudioQuery @@ -24,7 +23,7 @@ synthesis_morphing_parameter as _synthesis_morphing_parameter, ) from voicevox_engine.morphing.morphing import synthesize_morphed_wave -from voicevox_engine.tts_pipeline.tts_engine import TTSEngineManager +from voicevox_engine.tts_pipeline.tts_engine import LATEST_VERSION, TTSEngineManager from voicevox_engine.utility.file_utility import try_delete_file # キャッシュを有効化 @@ -34,9 +33,7 @@ def generate_morphing_router( - tts_engines: TTSEngineManager, - core_manager: CoreManager, - metas_store: MetasStore, + tts_engines: TTSEngineManager, metas_store: MetasStore ) -> APIRouter: """モーフィング API Router を生成する""" router = APIRouter(tags=["音声合成"]) @@ -89,7 +86,7 @@ def _synthesis_morphing( 指定された2種類のスタイルで音声を合成、指定した割合でモーフィングした音声を得ます。 モーフィングの割合は`morph_rate`で指定でき、0.0でベースのスタイル、1.0でターゲットのスタイルに近づきます。 """ - version = core_version or core_manager.latest_version() + version = core_version or LATEST_VERSION engine = tts_engines.get_engine(version) # モーフィングが許可されないキャラクターペアを拒否する diff --git a/voicevox_engine/app/routers/tts_pipeline.py b/voicevox_engine/app/routers/tts_pipeline.py index 7b113eb39..6555c844c 100644 --- a/voicevox_engine/app/routers/tts_pipeline.py +++ b/voicevox_engine/app/routers/tts_pipeline.py @@ -15,7 +15,6 @@ CancellableEngine, CancellableEngineInternalError, ) -from voicevox_engine.core.core_initializer import CoreManager from voicevox_engine.metas.Metas import StyleId from voicevox_engine.model import AudioQuery from voicevox_engine.preset.preset_manager import ( @@ -39,6 +38,7 @@ Score, ) from voicevox_engine.tts_pipeline.tts_engine import ( + LATEST_VERSION, TalkSingInvalidInputError, TTSEngineManager, ) @@ -65,7 +65,6 @@ def __init__(self, err: ParseKanaError): def generate_tts_pipeline_router( tts_engines: TTSEngineManager, - core_manager: CoreManager, preset_manager: PresetManager, cancellable_engine: CancellableEngine | None, ) -> APIRouter: @@ -85,7 +84,7 @@ def audio_query( """ 音声合成用のクエリの初期値を得ます。ここで得られたクエリはそのまま音声合成に利用できます。各値の意味は`Schemas`を参照してください。 """ - version = core_version or core_manager.latest_version() + version = core_version or LATEST_VERSION engine = tts_engines.get_engine(version) accent_phrases = engine.create_accent_phrases(text, style_id) return AudioQuery( @@ -116,7 +115,7 @@ def audio_query_from_preset( """ 音声合成用のクエリの初期値を得ます。ここで得られたクエリはそのまま音声合成に利用できます。各値の意味は`Schemas`を参照してください。 """ - version = core_version or core_manager.latest_version() + version = core_version or LATEST_VERSION engine = tts_engines.get_engine(version) try: presets = preset_manager.load_presets() @@ -175,7 +174,7 @@ def accent_phrases( * アクセント位置を`'`で指定する。全てのアクセント句にはアクセント位置を1つ指定する必要がある。 * アクセント句末に`?`(全角)を入れることにより疑問文の発音ができる。 """ - version = core_version or core_manager.latest_version() + version = core_version or LATEST_VERSION engine = tts_engines.get_engine(version) if is_kana: try: @@ -197,7 +196,7 @@ def mora_data( style_id: Annotated[StyleId, Query(alias="speaker")], core_version: str | SkipJsonSchema[None] = None, ) -> list[AccentPhrase]: - version = core_version or core_manager.latest_version() + version = core_version or LATEST_VERSION engine = tts_engines.get_engine(version) return engine.update_length_and_pitch(accent_phrases, style_id) @@ -211,7 +210,7 @@ def mora_length( style_id: Annotated[StyleId, Query(alias="speaker")], core_version: str | SkipJsonSchema[None] = None, ) -> list[AccentPhrase]: - version = core_version or core_manager.latest_version() + version = core_version or LATEST_VERSION engine = tts_engines.get_engine(version) return engine.update_length(accent_phrases, style_id) @@ -225,7 +224,7 @@ def mora_pitch( style_id: Annotated[StyleId, Query(alias="speaker")], core_version: str | SkipJsonSchema[None] = None, ) -> list[AccentPhrase]: - version = core_version or core_manager.latest_version() + version = core_version or LATEST_VERSION engine = tts_engines.get_engine(version) return engine.update_pitch(accent_phrases, style_id) @@ -253,7 +252,7 @@ def synthesis( ] = True, core_version: str | SkipJsonSchema[None] = None, ) -> FileResponse: - version = core_version or core_manager.latest_version() + version = core_version or LATEST_VERSION engine = tts_engines.get_engine(version) wave = engine.synthesize_wave( query, style_id, enable_interrogative_upspeak=enable_interrogative_upspeak @@ -294,8 +293,8 @@ def cancellable_synthesis( status_code=404, detail="実験的機能はデフォルトで無効になっています。使用するには引数を指定してください。", ) - version = core_version or core_manager.latest_version() try: + version = core_version or LATEST_VERSION f_name = cancellable_engine._synthesis_impl( query, style_id, request, version=version ) @@ -331,7 +330,7 @@ def multi_synthesis( style_id: Annotated[StyleId, Query(alias="speaker")], core_version: str | SkipJsonSchema[None] = None, ) -> FileResponse: - version = core_version or core_manager.latest_version() + version = core_version or LATEST_VERSION engine = tts_engines.get_engine(version) sampling_rate = queries[0].outputSamplingRate @@ -374,7 +373,7 @@ def sing_frame_audio_query( """ 歌唱音声合成用のクエリの初期値を得ます。ここで得られたクエリはそのまま歌唱音声合成に利用できます。各値の意味は`Schemas`を参照してください。 """ - version = core_version or core_manager.latest_version() + version = core_version or LATEST_VERSION engine = tts_engines.get_engine(version) try: phonemes, f0, volume = engine.create_sing_phoneme_and_f0_and_volume( @@ -403,7 +402,7 @@ def sing_frame_volume( style_id: Annotated[StyleId, Query(alias="speaker")], core_version: str | SkipJsonSchema[None] = None, ) -> list[float]: - version = core_version or core_manager.latest_version() + version = core_version or LATEST_VERSION engine = tts_engines.get_engine(version) try: return engine.create_sing_volume_from_phoneme_and_f0( @@ -432,7 +431,7 @@ def frame_synthesis( """ 歌唱音声合成を行います。 """ - version = core_version or core_manager.latest_version() + version = core_version or LATEST_VERSION engine = tts_engines.get_engine(version) try: wave = engine.frame_synthsize_wave(query, style_id) @@ -528,7 +527,7 @@ def initialize_speaker( 指定されたスタイルを初期化します。 実行しなくても他のAPIは使用できますが、初回実行時に時間がかかることがあります。 """ - version = core_version or core_manager.latest_version() + version = core_version or LATEST_VERSION engine = tts_engines.get_engine(version) engine.initialize_synthesis(style_id, skip_reinit=skip_reinit) @@ -540,7 +539,7 @@ def is_initialized_speaker( """ 指定されたスタイルが初期化されているかどうかを返します。 """ - version = core_version or core_manager.latest_version() + version = core_version or LATEST_VERSION engine = tts_engines.get_engine(version) return engine.is_synthesis_initialized(style_id) diff --git a/voicevox_engine/cancellable_engine.py b/voicevox_engine/cancellable_engine.py index 2899fb5e4..a812892f3 100644 --- a/voicevox_engine/cancellable_engine.py +++ b/voicevox_engine/cancellable_engine.py @@ -19,7 +19,7 @@ from .core.core_initializer import initialize_cores from .metas.Metas import StyleId from .model import AudioQuery -from .tts_pipeline.tts_engine import make_tts_engines_from_cores +from .tts_pipeline.tts_engine import LatestVersion, make_tts_engines_from_cores class CancellableEngineInternalError(Exception): @@ -149,7 +149,7 @@ def _synthesis_impl( query: AudioQuery, style_id: StyleId, request: Request, - version: str, + version: str | LatestVersion, ) -> str: """ 音声合成を行う関数 @@ -163,7 +163,7 @@ def _synthesis_impl( request: fastapi.Request 接続確立時に受け取ったものをそのまま渡せばよい https://fastapi.tiangolo.com/advanced/using-request-directly/ - version: str + version Returns ------- @@ -245,9 +245,9 @@ def start_synthesis_subprocess( while True: try: query, style_id, version = sub_proc_con.recv() - if tts_engines.has_engine(version): + try: _engine = tts_engines.get_engine(version) - else: + except Exception: # バージョンが見つからないエラー sub_proc_con.send("") continue diff --git a/voicevox_engine/tts_pipeline/tts_engine.py b/voicevox_engine/tts_pipeline/tts_engine.py index 91de3993a..5cde61724 100644 --- a/voicevox_engine/tts_pipeline/tts_engine.py +++ b/voicevox_engine/tts_pipeline/tts_engine.py @@ -2,12 +2,15 @@ import copy import math +from typing import Final, Literal, TypeAlias import numpy as np from fastapi import HTTPException from numpy.typing import NDArray from soxr import resample +from voicevox_engine.utility.core_version_utility import get_latest_version + from ..core.core_adapter import CoreAdapter, DeviceSupport from ..core.core_initializer import CoreManager from ..core.core_wrapper import CoreWrapper @@ -697,6 +700,10 @@ def frame_synthsize_wave( return wave +LatestVersion: TypeAlias = Literal["LATEST_VERSION"] +LATEST_VERSION: Final[LatestVersion] = "LATEST_VERSION" + + class TTSEngineManager: """TTS エンジンの集まりを一括管理するマネージャー""" @@ -707,13 +714,18 @@ def versions(self) -> list[str]: """登録されたエンジンのバージョン一覧を取得する。""" return list(self._engines.keys()) + def _latest_version(self) -> str: + return get_latest_version(self.versions()) + def register_engine(self, engine: TTSEngine, version: str) -> None: """エンジンを登録する。""" self._engines[version] = engine - def get_engine(self, version: str) -> TTSEngine: + def get_engine(self, version: str | LatestVersion) -> TTSEngine: """指定バージョンのエンジンを取得する。""" - if version in self._engines: + if version == LATEST_VERSION: + return self._engines[self._latest_version()] + elif version in self._engines: return self._engines[version] raise HTTPException(status_code=422, detail="不明なバージョンです")