Skip to content

Commit

Permalink
整理: GET /supported_devices API を tts_pipeline router へ移動 (#1444)
Browse files Browse the repository at this point in the history
refactor: `supported_devices` API を `tts_pipeline` router へ移動
  • Loading branch information
tarepan authored Jun 29, 2024
1 parent 925c4d5 commit abcfa78
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 44 deletions.
4 changes: 1 addition & 3 deletions voicevox_engine/app/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,7 @@ def _get_core_characters(version: str | None) -> list[CoreCharacter]:
generate_library_router(library_manager, verify_mutability_allowed)
)
app.include_router(generate_user_dict_router(user_dict, verify_mutability_allowed))
app.include_router(
generate_engine_info_router(core_version_list, tts_engines, engine_manifest)
)
app.include_router(generate_engine_info_router(core_version_list, engine_manifest))
app.include_router(
generate_setting_router(
setting_loader, engine_manifest.brand_name, verify_mutability_allowed
Expand Down
42 changes: 2 additions & 40 deletions voicevox_engine/app/routers/engine_info.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,13 @@
"""エンジンの情報機能を提供する API Router"""

from typing import Self

from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field
from pydantic.json_schema import SkipJsonSchema
from fastapi import APIRouter

from voicevox_engine import __version__
from voicevox_engine.core.core_adapter import DeviceSupport
from voicevox_engine.engine_manifest import EngineManifest
from voicevox_engine.tts_pipeline.tts_engine import LATEST_VERSION, TTSEngineManager


class SupportedDevicesInfo(BaseModel):
"""
対応しているデバイスの情報
"""

cpu: bool = Field(description="CPUに対応しているか")
cuda: bool = Field(description="CUDA(Nvidia GPU)に対応しているか")
dml: bool = Field(description="DirectML(Nvidia GPU/Radeon GPU等)に対応しているか")

@classmethod
def generate_from(cls, device_support: DeviceSupport) -> Self:
"""`DeviceSupport` インスタンスからこのインスタンスを生成する。"""
return cls(
cpu=device_support.cpu,
cuda=device_support.cuda,
dml=device_support.dml,
)


def generate_engine_info_router(
core_version_list: list[str],
tts_engine_manager: TTSEngineManager,
engine_manifest_data: EngineManifest,
core_version_list: list[str], engine_manifest_data: EngineManifest
) -> APIRouter:
"""エンジン情報 API Router を生成する"""
router = APIRouter(tags=["その他"])
Expand All @@ -49,17 +22,6 @@ async def core_versions() -> list[str]:
"""利用可能なコアのバージョン一覧を取得します。"""
return core_version_list

@router.get("/supported_devices")
def supported_devices(
core_version: str | SkipJsonSchema[None] = None,
) -> SupportedDevicesInfo:
"""対応デバイスの一覧を取得します。"""
version = core_version or LATEST_VERSION
supported_devices = tts_engine_manager.get_engine(version).supported_devices
if supported_devices is None:
raise HTTPException(status_code=422, detail="非対応の機能です。")
return SupportedDevicesInfo.generate_from(supported_devices)

@router.get("/engine_manifest")
async def engine_manifest() -> EngineManifest:
"""エンジンマニフェストを取得します。"""
Expand Down
33 changes: 32 additions & 1 deletion voicevox_engine/app/routers/tts_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import zipfile
from tempfile import NamedTemporaryFile, TemporaryFile
from typing import Annotated
from typing import Annotated, Self

import soundfile
from fastapi import APIRouter, HTTPException, Query, Request
Expand All @@ -15,6 +15,7 @@
CancellableEngine,
CancellableEngineInternalError,
)
from voicevox_engine.core.core_adapter import DeviceSupport
from voicevox_engine.metas.Metas import StyleId
from voicevox_engine.model import AudioQuery
from voicevox_engine.preset.preset_manager import (
Expand Down Expand Up @@ -63,6 +64,25 @@ def __init__(self, err: ParseKanaError):
super().__init__(text=err.text, error_name=err.errname, error_args=err.kwargs)


class SupportedDevicesInfo(BaseModel):
"""
対応しているデバイスの情報
"""

cpu: bool = Field(description="CPUに対応しているか")
cuda: bool = Field(description="CUDA(Nvidia GPU)に対応しているか")
dml: bool = Field(description="DirectML(Nvidia GPU/Radeon GPU等)に対応しているか")

@classmethod
def generate_from(cls, device_support: DeviceSupport) -> Self:
"""`DeviceSupport` インスタンスからこのインスタンスを生成する。"""
return cls(
cpu=device_support.cpu,
cuda=device_support.cuda,
dml=device_support.dml,
)


def generate_tts_pipeline_router(
tts_engines: TTSEngineManager,
preset_manager: PresetManager,
Expand Down Expand Up @@ -543,4 +563,15 @@ def is_initialized_speaker(
engine = tts_engines.get_engine(version)
return engine.is_synthesis_initialized(style_id)

@router.get("/supported_devices", tags=["その他"])
def supported_devices(
core_version: str | SkipJsonSchema[None] = None,
) -> SupportedDevicesInfo:
"""対応デバイスの一覧を取得します。"""
version = core_version or LATEST_VERSION
supported_devices = tts_engines.get_engine(version).supported_devices
if supported_devices is None:
raise HTTPException(status_code=422, detail="非対応の機能です。")
return SupportedDevicesInfo.generate_from(supported_devices)

return router

0 comments on commit abcfa78

Please sign in to comment.