Skip to content

Commit

Permalink
Merge #4: Strict Memory Control & Only Access Models in Worker
Browse files Browse the repository at this point in the history
  • Loading branch information
graphemecluster authored Aug 29, 2024
2 parents 4cbbfd9 + 624fe08 commit 003e8e3
Show file tree
Hide file tree
Showing 14 changed files with 244 additions and 202 deletions.
5 changes: 3 additions & 2 deletions index.html
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ <h2 class="group-hover:text-slate-700 group-hover:text-opacity-90 transition-[co
</div>
</div>
<dialog id="about-dialog" class="modal modal-bottom sm:modal-middle">
<div class="modal-box p-0 flex flex-col sm:max-w-3xl h-[calc(100%-5rem)]">
<div class="modal-box p-0 flex flex-col sm:max-w-3xl h-[calc(100%-5rem)] overflow-hidden">
<form method="dialog">
<button type="submit" class="btn btn-ghost w-14 h-14 min-h-14 text-4.5xl absolute right-3 top-3 text-slate-500 hover:bg-opacity-10" aria-label="關閉"><span class="icon-close"></span></button>
</form>
Expand All @@ -93,7 +93,8 @@ <h3 class="flex items-center gap-2 mx-6 mt-6 mb-4.5"><span class="icon-info"></s
>,詳情請參閱<a href="https://github.com/hkilang/TTS" target="_blank">原始碼</a>
</p>
<p>如有任何查詢,歡迎電郵至 <a href="mailto:info@hkilang.org" target="_blank">info@hkilang.org</a><a href="mailto:lchaakming@eduhk.hk" target="_blank">lchaakming@eduhk.hk</a></p>
<img src="./assets/credit-logos.svg" alt="康樂及文化事務署標誌、政府文體旅計劃「門常開」標誌、非物質文化遺產資助計劃標誌" title="康樂及文化事務署標誌、政府文體旅計劃「門常開」標誌、非物質文化遺產資助計劃標誌" />
<p>資助單位:</p>
<img src="./assets/credit-logos.svg" alt="康樂及文化事務署標誌、政府資助計劃「門常開」標誌、非物質文化遺產資助計劃標誌" title="康樂及文化事務署標誌、政府資助計劃「門常開」標誌、非物質文化遺產資助計劃標誌" />
<ol class="list-[squared-decimal] text-slate-400 text-sm ml-11">
<li id="vits2">
Jungil Kong, Jihoon Park, Beomjeong Kim, Jeongmin Kim, Dohee Kong, and Sangjin Kim. 2023. VITS2: Improving quality and efficiency of single-stage text-to-speech with adversarial learning and architecture design. In <cite>Proc. INTERSPEECH 2023</cite>, pp. 4374–4378. Available:
Expand Down
14 changes: 7 additions & 7 deletions public/assets/credit-logos.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
76 changes: 35 additions & 41 deletions src/AudioPlayer.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,25 @@ import { MdErrorOutline, MdFileDownload, MdPause, MdPlayArrow, MdRefresh, MdStop

import { getOffsetMap } from "./audio";
import { cachedFetch } from "./cache";
import { ALL_AUDIO_COMPONENTS, ALL_MODEL_COMPONENTS, DatabaseError, DOWNLOAD_TYPE_LABEL, FileNotDownloadedError, NO_AUTO_FILL, ServerError } from "./consts";
import { DOWNLOAD_TYPE_LABEL, NO_AUTO_FILL, TERMINOLOGY } from "./consts";
import { useDB } from "./db/DBContext";
import { CURRENT_AUDIO_VERSION, CURRENT_MODEL_VERSION } from "./db/version";
import { DatabaseError, ServerError } from "./errors";
import API from "./inference/api";

import type { DownloadComponentToFile, DownloadVersion, ModelComponentToFile, AudioComponentToFile, OfflineInferenceMode, AudioVersion, SentenceComponentState } from "./types";
import type { DownloadVersion, AudioComponentToFile, OfflineInferenceMode, AudioVersion, SentenceComponentState, Language, Voice } from "./types";
import type { SyntheticEvent } from "react";

const context = new AudioContext({ sampleRate: 44100 });
const audioCache = new Map<string, Map<string, AudioBuffer>>();

export class FileNotDownloadedError extends Error {
override name = "FileNotDownloadedError";
constructor(inferenceMode: OfflineInferenceMode, language: Language, voice: Voice, isComplete?: boolean, options?: ErrorOptions) {
super(`${TERMINOLOGY[language]}${TERMINOLOGY[voice]}${DOWNLOAD_TYPE_LABEL[inferenceMode]}尚未下載${isComplete ? "" : "完成"}`, options);
}
}

export default function AudioPlayer({
sentence: {
language,
Expand Down Expand Up @@ -69,55 +77,39 @@ export default function AudioPlayer({
const { db, error: dbInitError, retry: dbInitRetry } = useDB();
const [downloadError, setDownloadError] = useState<Error>();
const [downloadRetryCounter, downloadRetry] = useReducer((n: number) => n + 1, 0);
const [download, setDownload] = useState<DownloadComponentToFile>();
const [downloadVersion, setDownloadVersion] = useState<DownloadVersion>();

const store = inferenceMode === "offline" ? "models" : "audios";
const CURRENT_VERSION = inferenceMode === "offline" ? CURRENT_MODEL_VERSION : CURRENT_AUDIO_VERSION;
const ALL_COMPONENTS = inferenceMode === "offline" ? ALL_MODEL_COMPONENTS : ALL_AUDIO_COMPONENTS;

useEffect(() => {
async function getDownloadComponents() {
if (inferenceMode === "online" || !db || download || currSettingsDialogPage) return;
if (inferenceMode === "online" || !db || downloadVersion || currSettingsDialogPage) return;
setDownloadVersion(undefined);
setDownloadError(undefined);
setBuffer(undefined);
try {
const availableFiles = await db.getAllFromIndex(store, "language_voice", [language, voice]);
if (availableFiles.length !== ALL_COMPONENTS.length) {
setDownloadError(new FileNotDownloadedError(inferenceMode, language, voice, !availableFiles.length));
setDownloadState({ inferenceMode, language, voice, status: availableFiles.length ? "incomplete" : "available_for_download" });
return;
}
const components = {} as DownloadComponentToFile;
const versions = new Set<DownloadVersion>();
for (const file of availableFiles) {
components[file.component] = file;
versions.add(file.version);
}
if (versions.size !== 1) {
setDownloadError(new FileNotDownloadedError(inferenceMode, language, voice));
setDownloadState({ inferenceMode, language, voice, status: "incomplete" });
return;
}
setDownload(components);
setDownloadState({ inferenceMode, language, voice, status: versions.values().next().value === CURRENT_VERSION ? "latest" : "new_version_available" });
const fileStatus = await db.get(`${store}_status`, `${language}/${voice}`);
const isComplete = fileStatus && !fileStatus.missingComponents.size;
if (isComplete) setDownloadVersion(fileStatus.version);
else setDownloadError(new FileNotDownloadedError(inferenceMode, language, voice, !fileStatus));
setDownloadState({ inferenceMode, language, voice, status: fileStatus ? isComplete ? fileStatus.version === CURRENT_VERSION ? "latest" : "new_version_available" : "incomplete" : "available_for_download" });
}
catch (error) {
setDownloadError(new DatabaseError(`無法存取語音${DOWNLOAD_TYPE_LABEL[inferenceMode]}:資料庫出錯`, { cause: error }));
setDownloadError(new DatabaseError(`無法取得${DOWNLOAD_TYPE_LABEL[inferenceMode]}狀態:資料庫出錯`, { cause: error }));
}
}
void getDownloadComponents();
// `inferenceMode` and `voiceSpeed` intentionally excluded
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [db, download, language, voice, setDownloadState, currSettingsDialogPage, downloadRetryCounter]);
}, [db, language, voice, inferenceMode, voiceSpeed, setDownloadState, currSettingsDialogPage, downloadRetryCounter]);

const [generationError, setGenerationError] = useState<Error>();
const [generationRetryCounter, generationRetry] = useReducer((n: number) => n + 1, 0);
const text = syllables.join(" ");
useEffect(() => {
if (inferenceMode !== "online" && !download) return;
const [{ version }] = inferenceMode === "online" ? [{ version: "main" }] : Object.values(download!);
if (inferenceMode !== "online" && !downloadVersion) return;
async function generateAudio() {
const key = `${inferenceMode}/${voiceSpeed}/${version}/${language}/${voice}`;
const key = `${inferenceMode}/${voiceSpeed}/${downloadVersion}/${language}/${voice}`;
let textToBuffer = audioCache.get(key);
if (!textToBuffer) audioCache.set(key, textToBuffer = new Map<string, AudioBuffer>());
let buffer = textToBuffer.get(text);
Expand All @@ -136,22 +128,28 @@ export default function AudioPlayer({
}
}
catch (error) {
throw error instanceof ServerError ? error : new ServerError("載入失敗", undefined, { cause: error });
throw error instanceof ServerError ? error : new ServerError("無法載入音訊:網絡或伺服器錯誤", undefined, { cause: error });
}
break;
case "offline": {
const channelData = await API.infer(language, download as ModelComponentToFile, syllables, voiceSpeed);
const channelData = await API.infer(language, voice, syllables, voiceSpeed);
buffer = context.createBuffer(1, channelData.length, 44100);
buffer.copyToChannel(channelData, 0);
break;
}
case "lightweight": {
const components: Partial<AudioComponentToFile> = {};
const buffers = await Promise.all(syllables.map(async phrase => {
const component = phrase.includes(" ") ? "words" : "chars";
const offset = (await getOffsetMap(version as AudioVersion, language, voice, component)).get(phrase);
const offset = (await getOffsetMap(downloadVersion as AudioVersion, language, voice, component)).get(phrase);
if (!offset) return context.createBuffer(1, 8820, 44100);
const data = (download as AudioComponentToFile)[component].file;
return context.decodeAudioData(data.slice(...offset));
try {
components[component] ??= (await db!.get("audios", `${language}/${voice}/${component}`))!.file;
}
catch (error) {
throw new DatabaseError("無法存取語音數據:資料庫出錯", { cause: error });
}
return context.decodeAudioData(components[component].slice(...offset));
}));
buffer = context.createBuffer(1, buffers.reduce((length, buffer) => length + buffer.length, 0), 44100);
const channelData = buffer.getChannelData(0);
Expand All @@ -176,9 +174,8 @@ export default function AudioPlayer({
setGenerationError(undefined);
setBuffer(undefined);
void generateAudio();
// `inferenceMode` and `voiceSpeed` intentionally excluded
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [language, voice, download, text, generationRetryCounter]);
}, [language, voice, inferenceMode, voiceSpeed, downloadVersion, text, generationRetryCounter]);

useEffect(() => {
if (buffer && isPlaying === null) playAudio();
Expand Down Expand Up @@ -286,10 +283,7 @@ export default function AudioPlayer({
</>}
</button>
</div>
: <div className="flex items-center gap-3 font-medium">
{db ? inferenceMode === "online" || download ? "正在生成語音,請稍候……" : `正在存取語音${DOWNLOAD_TYPE_LABEL[inferenceMode]}……` : "資料庫載入中……"}
<span className="loading loading-spinner max-sm:w-8 sm:loading-lg" />
</div>}
: <span className="loading loading-spinner max-sm:w-8 sm:loading-lg" />}
</div>}
</div>;
}
8 changes: 4 additions & 4 deletions src/SettingsDialog.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ const SettingsDialog = forwardRef<HTMLDialogElement, SettingDialogProps>(functio
const downloadManagerInferenceMode = currSettingsDialogPage?.slice(0, -15) as OfflineInferenceMode;

return <dialog ref={ref} className="modal modal-bottom sm:modal-middle">
<div className="modal-box p-0 flex flex-col sm:max-w-3xl h-[calc(100%-5rem)]">
<div className="modal-box p-0 flex flex-col sm:max-w-3xl h-[calc(100%-5rem)] overflow-hidden">
<form method="dialog">
<button type="submit" className="btn btn-ghost w-14 h-14 min-h-14 text-4.5xl absolute right-3 top-3 text-slate-500 hover:bg-opacity-10" aria-label="關閉">
<span>
Expand Down Expand Up @@ -72,7 +72,7 @@ const SettingsDialog = forwardRef<HTMLDialogElement, SettingDialogProps>(functio
{ALL_INFERENCE_MODES.map(mode => {
const currModeDownloadState = downloadState.get(mode as OfflineInferenceMode)!;
return <li key={mode} className="relative">
<label className="btn btn-ghost gap-0 w-full rounded-none text-left font-normal pl-2 pr-4 py-4 h-auto min-h-0 border-0 border-b border-b-slate-300 text-slate-700 hover:border-b hover:bg-opacity-10">
<label className="flex items-center text-sm/4 pl-2 pr-4 py-4 border-b border-b-slate-300 text-slate-700 cursor-pointer transition-colors hover:bg-base-content hover:bg-opacity-10">
<div className="text-2xl flex items-center px-2">{INFERENCE_MODE_TO_ICON[mode]}</div>
<div className="flex-1 flex flex-col">
<div className="text-xl font-medium">{INFERENCE_MODE_TO_LABEL[mode]}</div>
Expand Down Expand Up @@ -100,7 +100,7 @@ const SettingsDialog = forwardRef<HTMLDialogElement, SettingDialogProps>(functio
</ul>
<h4 className="px-4 py-2 border-b">選項</h4>
<ul>
<li className="flex items-center w-full pl-2 pr-4 py-4 h-auto min-h-0 border-0 border-b border-b-slate-300 text-slate-700 hover:border-b hover:bg-opacity-10">
<li className="flex items-center pl-2 pr-4 py-4 border-b border-b-slate-300 text-slate-700">
<div className="text-2xl flex items-center px-2">
<MdSpeed size="1.25em" />
</div>
Expand All @@ -120,7 +120,7 @@ const SettingsDialog = forwardRef<HTMLDialogElement, SettingDialogProps>(functio
</div>
</div>
</li>
<li className="flex items-center w-full pl-2 pr-4 py-4 h-auto min-h-0 border-0 border-b border-b-slate-300 text-slate-700 hover:border-b hover:bg-opacity-10">
<li className="flex items-center pl-2 pr-4 py-4 border-b border-b-slate-300 text-slate-700">
<div className="text-2xl flex items-center px-2">
<MdShowChart size="1.25em" />
</div>
Expand Down
13 changes: 10 additions & 3 deletions src/audio.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { cachedFetch } from "./cache";
import { AUDIO_PATH_PREFIX, ServerError } from "./consts";
import { AUDIO_PATH_PREFIX } from "./consts";
import { ServerError } from "./errors";

import type { AudioComponent, AudioVersion, Language, Voice } from "./types";

Expand All @@ -11,8 +12,14 @@ export async function getOffsetMap(version: AudioVersion, language: Language, vo
const path = `${version}/${language}/${voice}/${component}`;
let offsetMap = offsetCache.get(path);
if (offsetMap) return offsetMap;
const response = await cachedFetch(`${AUDIO_PATH_PREFIX}@${path}.csv`);
if (!response.ok) throw new ServerError("載入失敗");
let response: Response;
try {
response = await cachedFetch(`${AUDIO_PATH_PREFIX}@${path}.csv`);
if (!response.ok) throw new ServerError("無法載入語音數據", await response.text());
}
catch (error) {
throw error instanceof ServerError ? error : new ServerError("無法載入語音數據:網絡或伺服器錯誤", undefined, { cause: error });
}
const iter = (await response.text())[Symbol.iterator]();
// Skip header
for (const char of iter) {
Expand Down
20 changes: 1 addition & 19 deletions src/consts.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ export const INFERENCE_MODE_TO_LABEL: Record<InferenceMode, string> = {
export const INFERENCE_MODE_TO_DESCRIPTION: Record<InferenceMode, string> = {
online: "在伺服器產生音訊。請注意,使用此模式可能會產生相關網絡費用。",
offline: "毋須網絡連線,直接於裝置進行運算並產生音訊。請注意,此模式僅適用於電腦或有大量可用記憶體的裝置,且需空間儲存模型。",
lightweight: "以輕巧方式快速於裝置產生音訊,質素較其餘兩個模式遜色。適用如記憶體容量較少的裝置。請注意,此模式仍需空間儲存數據。",
lightweight: "以輕巧方式快速於裝置產生音訊,質素較其餘兩個模式遜色。適用於記憶體容量較少的裝置。請注意,此模式仍需空間儲存數據。",
};

export const INFERENCE_MODE_TO_ICON: Record<InferenceMode, JSX.Element> = {
Expand Down Expand Up @@ -143,21 +143,3 @@ export const NO_AUTO_FILL = {
autoCapitalize: "off",
spellCheck: "false",
} as const;

export class ServerError extends Error {
constructor(name: string, ...args: Parameters<ErrorConstructor>) {
super(...args);
this.name = name;
}
}

export class DatabaseError extends Error {
override name = "DatabaseError";
}

export class FileNotDownloadedError extends Error {
override name = "FileNotDownloadedError";
constructor(inferenceMode: OfflineInferenceMode, language: Language, voice: Voice, isComplete?: boolean, options?: ErrorOptions) {
super(`${TERMINOLOGY[language]}${TERMINOLOGY[voice]}${DOWNLOAD_TYPE_LABEL[inferenceMode]}尚未下載${isComplete ? "" : "完成"}`, options);
}
}
32 changes: 2 additions & 30 deletions src/db/DBContext.tsx
Original file line number Diff line number Diff line change
@@ -1,38 +1,10 @@
import { useEffect, useReducer, createContext, useContext } from "react";

import { openDB } from "idb";
import { getDBInstance } from "./instance";

import type { TTSDB } from "../types";
import type { IDBPDatabase } from "idb";
import type { TTSDatabase } from "./instance";
import type { Dispatch } from "react";

type TTSDatabase = IDBPDatabase<TTSDB>;
let dbInstance: TTSDatabase | undefined;
let pendingDBInstance: Promise<TTSDatabase> | undefined;

function getDBInstance() {
async function createDBInstance() {
try {
return dbInstance = await openDB<TTSDB>("TTS", 2, {
upgrade(db) {
if (!db.objectStoreNames.contains("models")) {
const store = db.createObjectStore("models", { keyPath: "path" });
store.createIndex("language_voice", ["language", "voice"]);
}
if (!db.objectStoreNames.contains("audios")) {
const store = db.createObjectStore("audios", { keyPath: "path" });
store.createIndex("language_voice", ["language", "voice"]);
}
},
});
}
finally {
pendingDBInstance = undefined;
}
}
return dbInstance || (pendingDBInstance ||= createDBInstance());
}

interface DBState {
db: TTSDatabase | undefined;
error: Error | undefined;
Expand Down
Loading

0 comments on commit 003e8e3

Please sign in to comment.