diff --git a/Cargo.lock b/Cargo.lock index f6c4e39d80..c86b84db65 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17902,6 +17902,7 @@ dependencies = [ "tokio-stream", "tokio-util", "tracing", + "transcript", "url", "uuid", "vad-ext", @@ -17932,6 +17933,7 @@ dependencies = [ "tokio", "tokio-stream", "tracing", + "transcript", ] [[package]] @@ -19816,6 +19818,18 @@ dependencies = [ "ws-utils", ] +[[package]] +name = "transcript" +version = "0.1.0" +dependencies = [ + "data", + "owhisper-interface", + "serde", + "serde_json", + "specta", + "uuid", +] + [[package]] name = "transpose" version = "0.2.3" diff --git a/Cargo.toml b/Cargo.toml index e7ef294d5b..8efe59a387 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -108,6 +108,7 @@ hypr-transcribe-deepgram = { path = "crates/transcribe-deepgram", package = "tra hypr-transcribe-openai = { path = "crates/transcribe-openai", package = "transcribe-openai" } hypr-transcribe-proxy = { path = "crates/transcribe-proxy", package = "transcribe-proxy" } hypr-transcribe-whisper-local = { path = "crates/transcribe-whisper-local", package = "transcribe-whisper-local" } +hypr-transcript = { path = "crates/transcript", package = "transcript" } hypr-turso = { path = "crates/turso", package = "turso" } hypr-vad = { path = "crates/vad", package = "vad" } hypr-vad-ext = { path = "crates/vad-ext", package = "vad-ext" } diff --git a/apps/desktop/src/components/main/body/sessions/note-input/transcript/shared/index.tsx b/apps/desktop/src/components/main/body/sessions/note-input/transcript/shared/index.tsx index 94eea4354d..9655330f0f 100644 --- a/apps/desktop/src/components/main/body/sessions/note-input/transcript/shared/index.tsx +++ b/apps/desktop/src/components/main/body/sessions/note-input/transcript/shared/index.tsx @@ -1,9 +1,8 @@ import { TriangleAlert } from "lucide-react"; -import { type RefObject, useCallback, useMemo, useRef, useState } from "react"; +import { type RefObject, useCallback, useRef, useState } from "react"; import { useHotkeys } from "react-hotkeys-hook"; import type { DegradedError } from "@hypr/plugin-listener"; -import type { RuntimeSpeakerHint } from "@hypr/transcript"; import { cn } from "@hypr/utils"; import { useAudioPlayer } from "../../../../../../../contexts/audio-player/provider"; @@ -43,44 +42,7 @@ export function TranscriptContainer({ const editable = sessionMode === "inactive" && Object.keys(operations ?? {}).length > 0; - const partialWordsByChannel = useListener( - (state) => state.partialWordsByChannel, - ); - const partialHintsByChannel = useListener( - (state) => state.partialHintsByChannel, - ); - - const partialWords = useMemo( - () => Object.values(partialWordsByChannel).flat(), - [partialWordsByChannel], - ); - - const partialHints = useMemo(() => { - const channelIndices = Object.keys(partialWordsByChannel) - .map(Number) - .sort((a, b) => a - b); - - const offsetByChannel = new Map(); - let currentOffset = 0; - for (const channelIndex of channelIndices) { - offsetByChannel.set(channelIndex, currentOffset); - currentOffset += partialWordsByChannel[channelIndex]?.length ?? 0; - } - - const reindexedHints: RuntimeSpeakerHint[] = []; - for (const channelIndex of channelIndices) { - const hints = partialHintsByChannel[channelIndex] ?? []; - const offset = offsetByChannel.get(channelIndex) ?? 0; - for (const hint of hints) { - reindexedHints.push({ - ...hint, - wordIndex: hint.wordIndex + offset, - }); - } - } - - return reindexedHints; - }, [partialWordsByChannel, partialHintsByChannel]); + const partialWords = useListener((state) => state.partialWords); const containerRef = useRef(null); const [scrollElement, setScrollElement] = useState( @@ -167,11 +129,7 @@ export function TranscriptContainer({ ? partialWords : [] } - partialHints={ - index === transcriptIds.length - 1 && currentActive - ? partialHints - : [] - } + partialHints={[]} operations={operations} /> {index < transcriptIds.length - 1 && } diff --git a/apps/desktop/src/hooks/useRunBatch.ts b/apps/desktop/src/hooks/useRunBatch.ts index 159b2a7bd4..8086039442 100644 --- a/apps/desktop/src/hooks/useRunBatch.ts +++ b/apps/desktop/src/hooks/useRunBatch.ts @@ -12,14 +12,14 @@ import { updateTranscriptHints, updateTranscriptWords, } from "../store/transcript/utils"; -import type { HandlePersistCallback } from "../store/zustand/listener/transcript"; +import type { BatchPersistCallback } from "../store/zustand/listener/batch"; import { type Tab, useTabs } from "../store/zustand/tabs"; import { id } from "../utils"; import { useKeywords } from "./useKeywords"; import { useSTTConnection } from "./useSTTConnection"; type RunOptions = { - handlePersist?: HandlePersistCallback; + handlePersist?: BatchPersistCallback; model?: string; baseUrl?: string; apiKey?: string; @@ -99,10 +99,10 @@ export const useRunBatch = (sessionId: string) => { speaker_hints: "[]", }); - const handlePersist: HandlePersistCallback | undefined = + const handlePersist: BatchPersistCallback | undefined = options?.handlePersist; - const persist = + const persist: BatchPersistCallback = handlePersist ?? ((words, hints) => { if (words.length === 0) { diff --git a/apps/desktop/src/hooks/useStartListening.ts b/apps/desktop/src/hooks/useStartListening.ts index ca3c1f1d19..db25e48fb9 100644 --- a/apps/desktop/src/hooks/useStartListening.ts +++ b/apps/desktop/src/hooks/useStartListening.ts @@ -1,6 +1,7 @@ import { useCallback } from "react"; import { commands as analyticsCommands } from "@hypr/plugin-analytics"; +import type { SpeakerHint, TranscriptWord } from "@hypr/plugin-listener"; import { useConfigValue } from "../config/use-config"; import { useListener } from "../contexts/listener"; @@ -55,7 +56,10 @@ export function useStartListening(sessionId: string) { stt_model: conn.model, }); - const handlePersist: HandlePersistCallback = (words, hints) => { + const handlePersist: HandlePersistCallback = ( + words: TranscriptWord[], + speakerHints: SpeakerHint[], + ) => { if (words.length === 0) { return; } @@ -64,49 +68,24 @@ export function useStartListening(sessionId: string) { const existingWords = parseTranscriptWords(store, transcriptId); const existingHints = parseTranscriptHints(store, transcriptId); - const newWords: WordWithId[] = []; - const newWordIds: string[] = []; - - words.forEach((word) => { - const wordId = id(); - - newWords.push({ - id: wordId, - text: word.text, - start_ms: word.start_ms, - end_ms: word.end_ms, - channel: word.channel, - }); - - newWordIds.push(wordId); - }); - - const newHints: SpeakerHintWithId[] = []; - - if (conn.provider === "deepgram") { - hints.forEach((hint) => { - if (hint.data.type !== "provider_speaker_index") { - return; - } - - const wordId = newWordIds[hint.wordIndex]; - const word = words[hint.wordIndex]; - if (!wordId || !word) { - return; - } - - newHints.push({ - id: id(), - word_id: wordId, - type: "provider_speaker_index", - value: JSON.stringify({ - provider: hint.data.provider ?? conn.provider, - channel: hint.data.channel ?? word.channel, - speaker_index: hint.data.speaker_index, - }), - }); - }); - } + const newWords: WordWithId[] = words.map((w) => ({ + id: w.id, + text: w.text, + start_ms: w.start_ms, + end_ms: w.end_ms, + channel: w.channel, + })); + + const newHints: SpeakerHintWithId[] = speakerHints.map((h) => ({ + id: id(), + word_id: h.word_id, + type: "provider_speaker_index", + value: JSON.stringify({ + provider: conn.provider, + channel: words.find((w) => w.id === h.word_id)?.channel ?? 0, + speaker_index: h.speaker_index, + }), + })); updateTranscriptWords(store, transcriptId, [ ...existingWords, diff --git a/apps/desktop/src/store/zustand/listener/batch.ts b/apps/desktop/src/store/zustand/listener/batch.ts index 2f7ef0f9c5..f8b363e995 100644 --- a/apps/desktop/src/store/zustand/listener/batch.ts +++ b/apps/desktop/src/store/zustand/listener/batch.ts @@ -1,15 +1,20 @@ import type { StoreApi } from "zustand"; -import type { BatchResponse, StreamResponse } from "@hypr/plugin-listener2"; +import type { BatchResponse } from "@hypr/plugin-listener2"; +import type { SpeakerHint, TranscriptWord } from "@hypr/plugin-listener2"; import { ChannelProfile, type RuntimeSpeakerHint, type WordLike, } from "../../../utils/segment"; -import type { HandlePersistCallback } from "./transcript"; import { transformWordEntries } from "./utils"; +export type BatchPersistCallback = ( + words: WordLike[], + hints: RuntimeSpeakerHint[], +) => void; + export type BatchPhase = "importing" | "transcribing"; export type BatchState = { @@ -22,7 +27,7 @@ export type BatchState = { phase?: BatchPhase; } >; - batchPersist: Record; + batchPersist: Record; }; export type BatchActions = { @@ -30,12 +35,13 @@ export type BatchActions = { handleBatchResponse: (sessionId: string, response: BatchResponse) => void; handleBatchResponseStreamed: ( sessionId: string, - response: StreamResponse, + words: TranscriptWord[], + speakerHints: SpeakerHint[], percentage: number, ) => void; handleBatchFailed: (sessionId: string, error: string) => void; clearBatchSession: (sessionId: string) => void; - setBatchPersist: (sessionId: string, callback: HandlePersistCallback) => void; + setBatchPersist: (sessionId: string, callback: BatchPersistCallback) => void; clearBatchPersist: (sessionId: string) => void; }; @@ -83,27 +89,32 @@ export const createBatchSlice = ( }); }, - handleBatchResponseStreamed: (sessionId, response, percentage) => { + handleBatchResponseStreamed: (sessionId, words, speakerHints, percentage) => { const persist = get().batchPersist[sessionId]; - if (persist && response.type === "Results") { - const channelIndex = response.channel_index[0]; - const alternative = response.channel.alternatives[0]; - - if (channelIndex !== undefined && alternative) { - const [words, hints] = transformWordEntries( - alternative.words, - alternative.transcript, - channelIndex, - ); + if (persist && words.length > 0) { + const wordLikes: WordLike[] = words.map((w) => ({ + text: w.text, + start_ms: w.start_ms, + end_ms: w.end_ms, + channel: w.channel, + })); + + const hints: RuntimeSpeakerHint[] = speakerHints.map((h) => { + const wordIndex = words.findIndex((w) => w.id === h.word_id); + return { + wordIndex: wordIndex >= 0 ? wordIndex : 0, + data: { + type: "provider_speaker_index" as const, + speaker_index: h.speaker_index, + }, + }; + }); - if (words.length > 0) { - persist(words, hints); - } - } + persist(wordLikes, hints); } - const isComplete = response.type === "Results" && response.from_finalize; + const isComplete = percentage >= 1; set((state) => ({ ...state, diff --git a/apps/desktop/src/store/zustand/listener/general.test.ts b/apps/desktop/src/store/zustand/listener/general.test.ts index 6ed7331a69..4e32567663 100644 --- a/apps/desktop/src/store/zustand/listener/general.test.ts +++ b/apps/desktop/src/store/zustand/listener/general.test.ts @@ -39,35 +39,7 @@ describe("General Listener Slice", () => { const sessionId = "session-456"; const { handleBatchResponseStreamed, getSessionMode } = store.getState(); - const mockResponse = { - type: "Results" as const, - start: 0, - duration: 5, - is_final: false, - speech_final: false, - from_finalize: false, - channel: { - alternatives: [ - { - transcript: "test", - words: [], - confidence: 0.9, - }, - ], - }, - metadata: { - request_id: "test-request", - model_info: { - name: "test-model", - version: "1.0", - arch: "test-arch", - }, - model_uuid: "test-uuid", - }, - channel_index: [0], - }; - - handleBatchResponseStreamed(sessionId, mockResponse, 0.5); + handleBatchResponseStreamed(sessionId, [], [], 0.5); expect(getSessionMode(sessionId)).toBe("running_batch"); }); }); @@ -78,35 +50,7 @@ describe("General Listener Slice", () => { const { handleBatchResponseStreamed, clearBatchSession } = store.getState(); - const mockResponse = { - type: "Results" as const, - start: 0, - duration: 5, - is_final: false, - speech_final: false, - from_finalize: false, - channel: { - alternatives: [ - { - transcript: "test", - words: [], - confidence: 0.9, - }, - ], - }, - metadata: { - request_id: "test-request", - model_info: { - name: "test-model", - version: "1.0", - arch: "test-arch", - }, - model_uuid: "test-uuid", - }, - channel_index: [0], - }; - - handleBatchResponseStreamed(sessionId, mockResponse, 0.5); + handleBatchResponseStreamed(sessionId, [], [], 0.5); expect(store.getState().batch[sessionId]).toEqual({ percentage: 0.5, isComplete: false, diff --git a/apps/desktop/src/store/zustand/listener/general.ts b/apps/desktop/src/store/zustand/listener/general.ts index 0936839c9d..04179f1dae 100644 --- a/apps/desktop/src/store/zustand/listener/general.ts +++ b/apps/desktop/src/store/zustand/listener/general.ts @@ -15,7 +15,6 @@ import { type SessionLifecycleEvent, type SessionParams, type SessionProgressEvent, - type StreamResponse, } from "@hypr/plugin-listener"; import { type BatchParams, @@ -26,7 +25,7 @@ import { commands as settingsCommands } from "@hypr/plugin-settings"; import { fromResult } from "../../../effect"; import { buildSessionPath } from "../../tinybase/persister/shared/paths"; -import type { BatchActions, BatchState } from "./batch"; +import type { BatchActions, BatchPersistCallback, BatchState } from "./batch"; import type { HandlePersistCallback, TranscriptActions } from "./transcript"; type LiveSessionStatus = "inactive" | "active" | "finalizing"; @@ -65,7 +64,7 @@ export type GeneralActions = { setMuted: (value: boolean) => void; runBatch: ( params: BatchParams, - options?: { handlePersist?: HandlePersistCallback; sessionId?: string }, + options?: { handlePersist?: BatchPersistCallback; sessionId?: string }, ) => Promise; getSessionMode: (sessionId: string) => SessionMode; }; @@ -311,9 +310,12 @@ export const createGeneralSlice = < }; }), ); - } else if (payload.type === "stream_response") { - const response = payload.response; - get().handleTranscriptResponse(response as unknown as StreamResponse); + } else if (payload.type === "transcript_update") { + get().handleTranscriptUpdate( + payload.new_final_words, + payload.speaker_hints, + payload.partial_words, + ); } else if (payload.type === "mic_muted") { set((state) => mutate(state, (draft) => { @@ -529,7 +531,8 @@ export const createGeneralSlice = < if (payload.type === "batchProgress") { get().handleBatchResponseStreamed( sessionId, - payload.response, + payload.words, + payload.speaker_hints, payload.percentage, ); diff --git a/apps/desktop/src/store/zustand/listener/transcript.test.ts b/apps/desktop/src/store/zustand/listener/transcript.test.ts index 3f2ae661b6..232d5fb96d 100644 --- a/apps/desktop/src/store/zustand/listener/transcript.test.ts +++ b/apps/desktop/src/store/zustand/listener/transcript.test.ts @@ -1,9 +1,8 @@ -import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; +import { describe, expect, test, vi } from "vitest"; import { createStore } from "zustand"; -import type { StreamResponse, StreamWord } from "@hypr/plugin-listener"; +import type { PartialWord, TranscriptWord } from "@hypr/plugin-listener"; -import type { RuntimeSpeakerHint, WordLike } from "../../../utils/segment"; import { createTranscriptSlice, type TranscriptActions, @@ -16,262 +15,113 @@ const createTranscriptStore = () => { ); }; -describe("transcript slice", () => { - const defaultWords: StreamWord[] = [ - { - word: "another", - punctuated_word: "Another", - start: 0, - end: 1, - confidence: 1, - speaker: 0, - language: "en", - }, - { - word: "problem", - punctuated_word: "problem", - start: 1, - end: 2, - confidence: 1, - speaker: 1, - language: "en", - }, - ]; - - const createResponse = ({ - words, - transcript, - isFinal, - channelIndex = 0, - }: { - words: StreamWord[]; - transcript: string; - isFinal: boolean; - channelIndex?: number; - }): StreamResponse => { - return { - type: "Results", - start: 0, - duration: 0, - is_final: isFinal, - speech_final: isFinal, - from_finalize: false, - channel_index: [channelIndex], - channel: { - alternatives: [ - { - transcript, - confidence: 1, - words, - }, - ], - }, - metadata: { - request_id: "test", - model_info: { name: "model", version: "1", arch: "cpu" }, - model_uuid: "model", - }, - } satisfies StreamResponse; +function makeFinalWord( + text: string, + start_ms: number, + end_ms: number, + channel = 0, +): TranscriptWord { + return { + id: crypto.randomUUID(), + text, + start_ms, + end_ms, + channel, }; +} - type TranscriptStore = ReturnType; - let store: TranscriptStore; - - beforeEach(() => { - store = createTranscriptStore(); - }); - - afterEach(() => { - vi.useRealTimers(); - }); +function makePartialWord( + text: string, + start_ms: number, + end_ms: number, + channel = 0, +): PartialWord { + return { text, start_ms, end_ms, channel }; +} - test("stores partial words and hints from streaming updates", () => { - const initialPartial = createResponse({ - words: defaultWords, - transcript: "Another problem", - isFinal: false, - }); - - store.getState().handleTranscriptResponse(initialPartial); - - const stateAfterFirst = store.getState(); - const firstChannelWords = stateAfterFirst.partialWordsByChannel[0]; - expect(firstChannelWords).toHaveLength(2); - expect(firstChannelWords?.map((word) => word.text)).toEqual([ - " Another", - " problem", - ]); - expect(stateAfterFirst.partialHintsByChannel[0]).toHaveLength(2); - expect(stateAfterFirst.partialHintsByChannel[0]?.[0]?.wordIndex).toBe(0); - expect(stateAfterFirst.partialHintsByChannel[0]?.[1]?.wordIndex).toBe(1); - - const extendedPartial = createResponse({ - words: [ - ...defaultWords, - { - word: "exists", - punctuated_word: "exists", - start: 2, - end: 3, - confidence: 1, - speaker: 1, - language: "en", - }, - ], - transcript: "Another problem exists", - isFinal: false, - }); - - store.getState().handleTranscriptResponse(extendedPartial); +describe("transcript slice", () => { + test("handles partial words update", () => { + const store = createTranscriptStore(); + + store + .getState() + .handleTranscriptUpdate( + [], + [], + [ + makePartialWord(" Hello", 100, 500), + makePartialWord(" world", 550, 900), + ], + ); - const stateAfterSecond = store.getState(); - const updatedWords = stateAfterSecond.partialWordsByChannel[0]; - expect(updatedWords).toHaveLength(3); - expect(updatedWords?.map((word) => word.text)).toEqual([ - " Another", - " problem", - " exists", - ]); - const channelHints = stateAfterSecond.partialHintsByChannel[0] ?? []; - expect(channelHints).toHaveLength(3); - const lastPartialHint = channelHints[channelHints.length - 1]; - expect(lastPartialHint?.wordIndex).toBe(2); + const state = store.getState(); + expect(state.partialWords).toHaveLength(2); + expect(state.partialWords.map((w) => w.text)).toEqual([" Hello", " world"]); }); - test("persists only new final words", () => { + test("persists final words via callback", () => { + const store = createTranscriptStore(); const persist = vi.fn(); store.getState().setTranscriptPersist(persist); - const finalResponse = createResponse({ - words: [ - { - word: "hello", - punctuated_word: "Hello", - start: 0, - end: 0.5, - confidence: 1, - speaker: 0, - language: "en", - }, - { - word: "world", - punctuated_word: "world", - start: 0.5, - end: 1.5, - confidence: 1, - speaker: null, - language: "en", - }, - ], - transcript: "Hello world", - isFinal: true, - }); - - store.getState().handleTranscriptResponse(finalResponse); - expect(persist).toHaveBeenCalledTimes(1); - - const [words, hints] = persist.mock.calls[0] as [ - WordLike[], - RuntimeSpeakerHint[], + const finals = [ + makeFinalWord(" Hello", 100, 500), + makeFinalWord(" world", 550, 900), ]; - expect(words.map((word) => word.text)).toEqual([" Hello", " world"]); - expect(words.map((word) => word.end_ms)).toEqual([500, 1500]); - expect(hints).toEqual([ - { - data: { type: "provider_speaker_index", speaker_index: 0 }, - wordIndex: 0, - }, - ]); - store.getState().handleTranscriptResponse(finalResponse); + store.getState().handleTranscriptUpdate(finals, [], []); + expect(persist).toHaveBeenCalledTimes(1); - expect(store.getState().finalWordsMaxEndMsByChannel[0]).toBe(1500); + expect(persist).toHaveBeenCalledWith(finals, []); }); - test("adjusts partial hint indices after filtering partial words", () => { + test("does not call persist for empty finals", () => { + const store = createTranscriptStore(); const persist = vi.fn(); store.getState().setTranscriptPersist(persist); - const partialResponse = createResponse({ - words: [ - { - word: "hello", - punctuated_word: "Hello", - start: 0, - end: 0.5, - confidence: 1, - speaker: 0, - language: "en", - }, - { - word: "world", - punctuated_word: "world", - start: 0.5, - end: 1.0, - confidence: 1, - speaker: 1, - language: "en", - }, - { - word: "test", - punctuated_word: "test", - start: 1.1, - end: 1.5, - confidence: 1, - speaker: 0, - language: "en", - }, - ], - transcript: "Hello world test", - isFinal: false, - }); + store + .getState() + .handleTranscriptUpdate([], [], [makePartialWord(" partial", 100, 500)]); - store.getState().handleTranscriptResponse(partialResponse); + expect(persist).not.toHaveBeenCalled(); + }); - const stateAfterPartial = store.getState(); - expect(stateAfterPartial.partialWordsByChannel[0]).toHaveLength(3); - expect(stateAfterPartial.partialHintsByChannel[0]).toHaveLength(3); + test("atomic final + partial update", () => { + const store = createTranscriptStore(); + const persist = vi.fn(); + store.getState().setTranscriptPersist(persist); - const finalResponse = createResponse({ - words: [ - { - word: "hello", - punctuated_word: "Hello", - start: 0, - end: 0.5, - confidence: 1, - speaker: 0, - language: "en", - }, - { - word: "world", - punctuated_word: "world", - start: 0.5, - end: 1.0, - confidence: 1, - speaker: 1, - language: "en", - }, - ], - transcript: "Hello world", - isFinal: true, - }); + store + .getState() + .handleTranscriptUpdate( + [makeFinalWord(" Hello", 100, 500)], + [], + [ + makePartialWord(" world", 550, 900), + makePartialWord(" how", 950, 1200), + ], + ); - store.getState().handleTranscriptResponse(finalResponse); + expect(persist).toHaveBeenCalledTimes(1); + const state = store.getState(); + expect(state.partialWords).toHaveLength(2); + expect(state.partialWords.map((w) => w.text)).toEqual([" world", " how"]); + }); - const stateAfterFinal = store.getState(); - const remainingPartialWords = stateAfterFinal.partialWordsByChannel[0]; - const remainingHints = stateAfterFinal.partialHintsByChannel[0] ?? []; + test("reset clears partials and callback", () => { + const store = createTranscriptStore(); + const persist = vi.fn(); + store.getState().setTranscriptPersist(persist); - expect(remainingPartialWords).toHaveLength(1); - expect(remainingPartialWords?.[0]?.text).toBe(" test"); + store + .getState() + .handleTranscriptUpdate([], [], [makePartialWord(" hello", 100, 500)]); - expect(remainingHints).toHaveLength(1); - expect(remainingHints[0]?.wordIndex).toBe(0); + store.getState().resetTranscript(); - const hintedWord = - remainingPartialWords?.[remainingHints[0]?.wordIndex ?? -1]; - expect(hintedWord).toBeDefined(); - expect(hintedWord?.text).toBe(" test"); + const state = store.getState(); + expect(state.partialWords).toHaveLength(0); + expect(state.handlePersist).toBeUndefined(); }); }); diff --git a/apps/desktop/src/store/zustand/listener/transcript.ts b/apps/desktop/src/store/zustand/listener/transcript.ts index 6c7858dd56..90db690831 100644 --- a/apps/desktop/src/store/zustand/listener/transcript.ts +++ b/apps/desktop/src/store/zustand/listener/transcript.ts @@ -1,35 +1,34 @@ import { create as mutate } from "mutative"; import type { StoreApi } from "zustand"; -import type { StreamResponse } from "@hypr/plugin-listener"; - -import type { RuntimeSpeakerHint, WordLike } from "../../../utils/segment"; -import { transformWordEntries } from "./utils"; - -type WordsByChannel = Record; +import type { + PartialWord, + SpeakerHint, + TranscriptWord, +} from "@hypr/plugin-listener"; export type HandlePersistCallback = ( - words: WordLike[], - hints: RuntimeSpeakerHint[], + words: TranscriptWord[], + speakerHints: SpeakerHint[], ) => void; export type TranscriptState = { - finalWordsMaxEndMsByChannel: Record; - partialWordsByChannel: WordsByChannel; - partialHintsByChannel: Record; + partialWords: PartialWord[]; handlePersist?: HandlePersistCallback; }; export type TranscriptActions = { setTranscriptPersist: (callback?: HandlePersistCallback) => void; - handleTranscriptResponse: (response: StreamResponse) => void; + handleTranscriptUpdate: ( + newFinalWords: TranscriptWord[], + speakerHints: SpeakerHint[], + partialWords: PartialWord[], + ) => void; resetTranscript: () => void; }; const initialState: TranscriptState = { - finalWordsMaxEndMsByChannel: {}, - partialWordsByChannel: {}, - partialHintsByChannel: {}, + partialWords: [], handlePersist: undefined, }; @@ -39,111 +38,6 @@ export const createTranscriptSlice = < set: StoreApi["setState"], get: StoreApi["getState"], ): TranscriptState & TranscriptActions => { - const handleFinalWords = ( - channelIndex: number, - words: WordLike[], - hints: RuntimeSpeakerHint[], - ): void => { - const { - partialWordsByChannel, - partialHintsByChannel, - handlePersist, - finalWordsMaxEndMsByChannel, - } = get(); - - const lastPersistedEndMs = finalWordsMaxEndMsByChannel[channelIndex] ?? 0; - const lastEndMs = getLastEndMs(words); - - const firstNewWordIndex = words.findIndex( - (word) => word.end_ms > lastPersistedEndMs, - ); - if (firstNewWordIndex === -1) { - return; - } - - const newWords = words.slice(firstNewWordIndex); - const newHints = hints - .filter((hint) => hint.wordIndex >= firstNewWordIndex) - .map((hint) => ({ - ...hint, - wordIndex: hint.wordIndex - firstNewWordIndex, - })); - - const existingPartialWords = partialWordsByChannel[channelIndex] ?? []; - const remainingPartialWords = existingPartialWords.filter( - (word) => word.start_ms > lastEndMs, - ); - - const oldToNewIndex = new Map(); - let newIdx = 0; - for (let oldIdx = 0; oldIdx < existingPartialWords.length; oldIdx++) { - if (existingPartialWords[oldIdx].start_ms > lastEndMs) { - oldToNewIndex.set(oldIdx, newIdx); - newIdx++; - } - } - - const existingPartialHints = partialHintsByChannel[channelIndex] ?? []; - const remainingPartialHints = existingPartialHints - .filter((hint) => oldToNewIndex.has(hint.wordIndex)) - .map((hint) => ({ - ...hint, - wordIndex: oldToNewIndex.get(hint.wordIndex)!, - })); - - set((state) => - mutate(state, (draft) => { - draft.partialWordsByChannel[channelIndex] = remainingPartialWords; - draft.partialHintsByChannel[channelIndex] = remainingPartialHints; - draft.finalWordsMaxEndMsByChannel[channelIndex] = lastEndMs; - }), - ); - - handlePersist?.(newWords, newHints); - }; - - const handlePartialWords = ( - channelIndex: number, - words: WordLike[], - hints: RuntimeSpeakerHint[], - ): void => { - const { partialWordsByChannel, partialHintsByChannel } = get(); - const existing = partialWordsByChannel[channelIndex] ?? []; - - const firstStartMs = getFirstStartMs(words); - const lastEndMs = getLastEndMs(words); - - const [before, after] = [ - existing.filter((word) => word.end_ms <= firstStartMs), - existing.filter((word) => word.start_ms >= lastEndMs), - ]; - - const newWords = [...before, ...words, ...after]; - - const hintsWithAdjustedIndices = hints.map((hint) => ({ - ...hint, - wordIndex: before.length + hint.wordIndex, - })); - - const existingHints = partialHintsByChannel[channelIndex] ?? []; - const filteredOldHints = existingHints.filter((hint) => { - const word = existing[hint.wordIndex]; - return ( - word && (word.end_ms <= firstStartMs || word.start_ms >= lastEndMs) - ); - }); - - set((state) => - mutate(state, (draft) => { - draft.partialWordsByChannel[channelIndex] = newWords; - draft.partialHintsByChannel[channelIndex] = [ - ...filteredOldHints, - ...hintsWithAdjustedIndices, - ]; - }), - ); - }; - return { ...initialState, setTranscriptPersist: (callback) => { @@ -153,77 +47,26 @@ export const createTranscriptSlice = < }), ); }, - handleTranscriptResponse: (response) => { - if (response.type !== "Results") { - return; - } + handleTranscriptUpdate: (newFinalWords, speakerHints, partialWords) => { + const { handlePersist } = get(); - const channelIndex = response.channel_index[0]; - const alternative = response.channel.alternatives[0]; - if (channelIndex === undefined || !alternative) { - return; + if (newFinalWords.length > 0) { + handlePersist?.(newFinalWords, speakerHints); } - const [words, hints] = transformWordEntries( - alternative.words, - alternative.transcript, - channelIndex, + set((state) => + mutate(state, (draft) => { + draft.partialWords = partialWords; + }), ); - if (!words.length) { - return; - } - - if (response.is_final) { - handleFinalWords(channelIndex, words, hints); - } else { - handlePartialWords(channelIndex, words, hints); - } }, resetTranscript: () => { - const { partialWordsByChannel, partialHintsByChannel, handlePersist } = - get(); - - const remainingWords = Object.values(partialWordsByChannel).flat(); - - const channelIndices = Object.keys(partialWordsByChannel) - .map(Number) - .sort((a, b) => a - b); - - const offsetByChannel = new Map(); - let currentOffset = 0; - for (const channelIndex of channelIndices) { - offsetByChannel.set(channelIndex, currentOffset); - currentOffset += partialWordsByChannel[channelIndex]?.length ?? 0; - } - - const remainingHints: RuntimeSpeakerHint[] = []; - for (const channelIndex of channelIndices) { - const hints = partialHintsByChannel[channelIndex] ?? []; - const offset = offsetByChannel.get(channelIndex) ?? 0; - for (const hint of hints) { - remainingHints.push({ - ...hint, - wordIndex: hint.wordIndex + offset, - }); - } - } - - if (remainingWords.length > 0) { - handlePersist?.(remainingWords, remainingHints); - } - set((state) => mutate(state, (draft) => { - draft.partialWordsByChannel = {}; - draft.partialHintsByChannel = {}; - draft.finalWordsMaxEndMsByChannel = {}; + draft.partialWords = []; draft.handlePersist = undefined; }), ); }, }; }; - -const getLastEndMs = (words: WordLike[]): number => - words[words.length - 1]?.end_ms ?? 0; -const getFirstStartMs = (words: WordLike[]): number => words[0]?.start_ms ?? 0; diff --git a/crates/transcribe-proxy/tests/common/mod.rs b/crates/transcribe-proxy/tests/common/mod.rs index ca3bad2b6c..984f51d099 100644 --- a/crates/transcribe-proxy/tests/common/mod.rs +++ b/crates/transcribe-proxy/tests/common/mod.rs @@ -105,17 +105,25 @@ pub fn test_audio_stream_with_rate( Item = owhisper_interface::MixedMessage, > + Send + Unpin ++ 'static { + test_audio_stream_from_path(hypr_data::english_1::AUDIO_PATH, sample_rate) +} + +pub fn test_audio_stream_from_path( + path: &str, + sample_rate: u32, +) -> impl futures_util::Stream< + Item = owhisper_interface::MixedMessage, +> + Send ++ Unpin + 'static { use hypr_audio_utils::AudioFormatExt; - // chunk_samples should be proportional to sample_rate to maintain 100ms chunks let chunk_samples = (sample_rate / 10) as usize; - let audio = rodio::Decoder::new(std::io::BufReader::new( - std::fs::File::open(hypr_data::english_1::AUDIO_PATH).unwrap(), - )) - .unwrap() - .to_i16_le_chunks(sample_rate, chunk_samples); + let audio = rodio::Decoder::new(std::io::BufReader::new(std::fs::File::open(path).unwrap())) + .unwrap() + .to_i16_le_chunks(sample_rate, chunk_samples); Box::pin(tokio_stream::StreamExt::throttle( audio.map(owhisper_interface::MixedMessage::Audio), diff --git a/crates/transcribe-proxy/tests/record_fixtures.rs b/crates/transcribe-proxy/tests/record_fixtures.rs index 48a2da42b3..0e7435a484 100644 --- a/crates/transcribe-proxy/tests/record_fixtures.rs +++ b/crates/transcribe-proxy/tests/record_fixtures.rs @@ -4,6 +4,7 @@ use common::recording::{RecordingOptions, RecordingSession}; use common::*; use futures_util::StreamExt; +use std::path::Path; use std::time::Duration; use owhisper_client::Provider; @@ -12,11 +13,14 @@ use owhisper_interface::stream::StreamResponse; async fn record_live_fixture( provider: Provider, + audio_path: &str, + languages: Vec, recording_opts: RecordingOptions, - sample_rate: u32, + json_array_output: Option<&Path>, ) { let _ = tracing_subscriber::fmt::try_init(); + let sample_rate = provider.default_live_sample_rate(); let api_key = std::env::var(provider.env_key_name()) .unwrap_or_else(|_| panic!("{} must be set", provider.env_key_name())); let addr = start_server_with_provider(provider, api_key).await; @@ -32,7 +36,7 @@ async fn record_live_fixture( .api_base(format!("http://{}", addr)) .params(owhisper_interface::ListenParams { model: Some(provider.default_live_model().to_string()), - languages: vec![hypr_language::ISO639::En.into()], + languages, sample_rate, ..Default::default() }) @@ -40,12 +44,13 @@ async fn record_live_fixture( .await; let provider_name = format!("record:{}", provider); - let input = test_audio_stream_with_rate(sample_rate); + let input = test_audio_stream_from_path(audio_path, sample_rate); let (stream, handle) = client.from_realtime_audio(input).await.unwrap(); futures_util::pin_mut!(stream); + let mut responses: Vec = Vec::new(); let mut saw_transcript = false; - let timeout = Duration::from_secs(30); + let timeout = Duration::from_secs(120); let test_future = async { while let Some(result) = stream.next().await { @@ -68,6 +73,8 @@ async fn record_live_fixture( } } } + + responses.push(response); } Err(e) => { panic!("[{}] error: {:?}", provider_name, e); @@ -89,6 +96,15 @@ async fn record_live_fixture( } } + if let Some(output_path) = json_array_output { + if let Some(parent) = output_path.parent() { + std::fs::create_dir_all(parent).expect("failed to create output directory"); + } + let json = serde_json::to_string_pretty(&responses).expect("failed to serialize responses"); + std::fs::write(output_path, json).expect("failed to write fixture"); + println!("[{}] Fixture saved to {:?}", provider_name, output_path); + } + assert!( saw_transcript, "[{}] expected at least one non-empty transcript", @@ -98,14 +114,38 @@ async fn record_live_fixture( macro_rules! record_fixture_test { ($name:ident, $adapter:ty, $provider:expr) => { + record_fixture_test!( + $name, $adapter, $provider, + hypr_data::english_1::AUDIO_PATH, + vec![hypr_language::ISO639::En.into()], + @no_output + ); + }; + ($name:ident, $adapter:ty, $provider:expr, $audio:expr, $langs:expr, $output:literal) => { + #[ignore] + #[tokio::test] + async fn $name() { + let output_path = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR")).join($output); + record_live_fixture::<$adapter>( + $provider, + $audio, + $langs, + RecordingOptions::from_env("normal"), + Some(&output_path), + ) + .await; + } + }; + ($name:ident, $adapter:ty, $provider:expr, $audio:expr, $langs:expr, @no_output) => { #[ignore] #[tokio::test] async fn $name() { - let sample_rate = $provider.default_live_sample_rate(); record_live_fixture::<$adapter>( $provider, + $audio, + $langs, RecordingOptions::from_env("normal"), - sample_rate, + None, ) .await; } @@ -138,4 +178,13 @@ mod record { owhisper_client::ElevenLabsAdapter, Provider::ElevenLabs ); + + record_fixture_test!( + soniox_korean, + owhisper_client::SonioxAdapter, + Provider::Soniox, + hypr_data::korean_1::AUDIO_PATH, + vec![hypr_language::ISO639::Ko.into()], + "../../crates/transcript/src/accumulator/fixtures/soniox_2.json" + ); } diff --git a/crates/transcript/Cargo.toml b/crates/transcript/Cargo.toml new file mode 100644 index 0000000000..9fd1d63589 --- /dev/null +++ b/crates/transcript/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "transcript" +version = "0.1.0" +edition = "2024" + +[dev-dependencies] +hypr-data = { workspace = true } +serde_json = { workspace = true } + +[dependencies] +owhisper-interface = { workspace = true } + +serde = { workspace = true } +specta = { workspace = true } +uuid = { workspace = true, features = ["v4"] } diff --git a/crates/transcript/src/accumulator/channel.rs b/crates/transcript/src/accumulator/channel.rs new file mode 100644 index 0000000000..10fc6785ad --- /dev/null +++ b/crates/transcript/src/accumulator/channel.rs @@ -0,0 +1,52 @@ +use super::words::{ + RawWord, SpeakerHint, TranscriptWord, dedup, finalize_words, splice, stitch, strip_overlap, +}; + +pub(super) struct ChannelState { + watermark: i64, + held: Option, + partials: Vec, +} + +impl ChannelState { + pub(super) fn new() -> Self { + Self { + watermark: 0, + held: None, + partials: Vec::new(), + } + } + + pub(super) fn partials(&self) -> &[RawWord] { + &self.partials + } + + pub(super) fn apply_final( + &mut self, + words: Vec, + ) -> (Vec, Vec) { + let response_end = words.last().map_or(0, |w| w.end_ms); + let new_words: Vec<_> = dedup(words, self.watermark); + + if new_words.is_empty() { + return (vec![], vec![]); + } + + self.watermark = response_end; + self.partials = strip_overlap(std::mem::take(&mut self.partials), response_end); + + let (emitted, held) = stitch(self.held.take(), new_words); + self.held = held; + finalize_words(emitted) + } + + pub(super) fn apply_partial(&mut self, words: Vec) { + self.partials = splice(&self.partials, words); + } + + pub(super) fn drain(&mut self) -> (Vec, Vec) { + let mut raw: Vec<_> = self.held.take().into_iter().collect(); + raw.extend(std::mem::take(&mut self.partials)); + finalize_words(raw) + } +} diff --git a/crates/transcript/src/accumulator/mod.rs b/crates/transcript/src/accumulator/mod.rs new file mode 100644 index 0000000000..847da5c244 --- /dev/null +++ b/crates/transcript/src/accumulator/mod.rs @@ -0,0 +1,507 @@ +//! # Transcript-as-Oracle Accumulator +//! +//! The transcript string in each ASR response is the **sole source of truth** +//! for word boundaries. Tokens are sub-word fragments with timing metadata; +//! the transcript tells us which fragments belong to the same word. +//! +//! ## Two-level design +//! +//! **Within a response** — `assemble` aligns tokens to the transcript via +//! `spacing_from_transcript`. A space in the transcript means "new word"; +//! no space means "same word." No timing heuristics. +//! +//! **Across responses** — `stitch` handles the one case where no transcript +//! spans both sides: when a provider splits a word across two final responses +//! (e.g. Korean particles like "시스템" + "을" → "시스템을"). This uses a +//! timing-based heuristic because no cross-response transcript exists. + +mod channel; +mod words; + +use std::collections::BTreeMap; + +use owhisper_interface::stream::StreamResponse; + +pub use words::{PartialWord, SpeakerHint, TranscriptUpdate, TranscriptWord}; + +use channel::ChannelState; +use words::{assemble, ensure_space_prefix_partial}; + +/// Accumulates streaming ASR responses into clean, deduplicated transcript data. +/// +/// Each `process` call returns a `TranscriptUpdate` with: +/// - `new_final_words`: words that became final since the last update (ready to persist) +/// - `speaker_hints`: speaker associations for the newly finalized words +/// - `partial_words`: current in-progress words across all channels (for live display) +/// +/// Call `flush` at session end to drain any held/partial words that were never finalized. +pub struct TranscriptAccumulator { + channels: BTreeMap, +} + +impl TranscriptAccumulator { + pub fn new() -> Self { + Self { + channels: BTreeMap::new(), + } + } + + pub fn process(&mut self, response: &StreamResponse) -> Option { + let (is_final, channel, channel_index) = match response { + StreamResponse::TranscriptResponse { + is_final, + channel, + channel_index, + .. + } => (*is_final, channel, channel_index), + _ => return None, + }; + + let alt = channel.alternatives.first()?; + if alt.words.is_empty() && alt.transcript.is_empty() { + return None; + } + + let ch = channel_index.first().copied().unwrap_or(0); + let words = assemble(&alt.words, &alt.transcript, ch); + if words.is_empty() { + return None; + } + + let state = self.channels.entry(ch).or_insert_with(ChannelState::new); + + let (new_final_words, speaker_hints) = if is_final { + state.apply_final(words) + } else { + state.apply_partial(words); + (vec![], vec![]) + }; + + Some(TranscriptUpdate { + new_final_words, + speaker_hints, + partial_words: self.all_partials(), + }) + } + + pub fn flush(&mut self) -> TranscriptUpdate { + let mut new_final_words = Vec::new(); + let mut speaker_hints = Vec::new(); + + for state in self.channels.values_mut() { + let (words, hints) = state.drain(); + new_final_words.extend(words); + speaker_hints.extend(hints); + } + + TranscriptUpdate { + new_final_words, + speaker_hints, + partial_words: vec![], + } + } + + fn all_partials(&self) -> Vec { + let mut partials: Vec = self + .channels + .values() + .flat_map(|state| state.partials().iter().map(|w| w.to_partial())) + .collect(); + + if let Some(first) = partials.first_mut() { + ensure_space_prefix_partial(first); + } + + partials + } +} + +impl Default for TranscriptAccumulator { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use owhisper_interface::stream::{Alternatives, Channel, Metadata, ModelInfo}; + + fn raw_word( + text: &str, + start: f64, + end: f64, + speaker: Option, + ) -> owhisper_interface::stream::Word { + owhisper_interface::stream::Word { + word: text.to_string(), + start, + end, + confidence: 1.0, + speaker, + punctuated_word: Some(text.to_string()), + language: None, + } + } + + fn response( + words: &[(&str, f64, f64, Option)], + transcript: &str, + is_final: bool, + channel_idx: i32, + ) -> StreamResponse { + StreamResponse::TranscriptResponse { + start: 0.0, + duration: 0.0, + is_final, + speech_final: is_final, + from_finalize: false, + channel: Channel { + alternatives: vec![Alternatives { + transcript: transcript.to_string(), + words: words + .iter() + .map(|&(t, s, e, sp)| raw_word(t, s, e, sp)) + .collect(), + confidence: 1.0, + languages: vec![], + }], + }, + metadata: Metadata { + request_id: String::new(), + model_info: ModelInfo { + name: String::new(), + version: String::new(), + arch: String::new(), + }, + model_uuid: String::new(), + extra: None, + }, + channel_index: vec![channel_idx], + } + } + + fn partial(words: &[(&str, f64, f64)], transcript: &str) -> StreamResponse { + let ws: Vec<_> = words.iter().map(|&(t, s, e)| (t, s, e, None)).collect(); + response(&ws, transcript, false, 0) + } + + fn finalize(words: &[(&str, f64, f64)], transcript: &str) -> StreamResponse { + let ws: Vec<_> = words.iter().map(|&(t, s, e)| (t, s, e, None)).collect(); + response(&ws, transcript, true, 0) + } + + fn finalize_with_speakers( + words: &[(&str, f64, f64, Option)], + transcript: &str, + ) -> StreamResponse { + response(words, transcript, true, 0) + } + + fn replay(responses: &[StreamResponse]) -> Vec { + let mut acc = TranscriptAccumulator::new(); + let mut words = Vec::new(); + + for r in responses { + if let Some(update) = acc.process(r) { + words.extend(update.new_final_words); + } + } + + words.extend(acc.flush().new_final_words); + words + } + + fn assert_valid_output(words: &[TranscriptWord]) { + assert!(!words.is_empty(), "must produce words"); + + assert!( + words.iter().all(|w| !w.id.is_empty()), + "all words must have IDs" + ); + + let ids: std::collections::HashSet<_> = words.iter().map(|w| &w.id).collect(); + assert_eq!(ids.len(), words.len(), "IDs must be unique"); + + for w in words { + assert!( + !w.text.trim().is_empty(), + "word text must not be blank: {w:?}" + ); + assert!( + w.text.starts_with(' '), + "word must start with space: {:?}", + w.text + ); + } + + for ch in words + .iter() + .map(|w| w.channel) + .collect::>() + { + let cw: Vec<_> = words.iter().filter(|w| w.channel == ch).collect(); + assert!( + cw.windows(2).all(|w| w[0].start_ms <= w[1].start_ms), + "channel {ch} must be chronological" + ); + } + } + + #[test] + fn partial_update_exposes_current_words() { + let mut acc = TranscriptAccumulator::new(); + + let update = acc + .process(&partial( + &[(" Hello", 0.1, 0.5), (" world", 0.6, 0.9)], + " Hello world", + )) + .unwrap(); + + assert!(update.new_final_words.is_empty()); + assert_eq!(update.partial_words.len(), 2); + assert_eq!( + update + .partial_words + .iter() + .map(|w| &w.text) + .collect::>(), + [" Hello", " world"] + ); + } + + #[test] + fn partial_splices_into_existing_window() { + let mut acc = TranscriptAccumulator::new(); + + acc.process(&partial( + &[(" Hello", 0.1, 0.5), (" world", 0.6, 0.9)], + " Hello world", + )); + + let update = acc + .process(&partial( + &[ + (" Hello", 0.1, 0.5), + (" world", 0.6, 0.9), + (" today", 1.0, 1.3), + ], + " Hello world today", + )) + .unwrap(); + + assert_eq!(update.partial_words.len(), 3); + assert_eq!( + update + .partial_words + .iter() + .map(|w| &w.text) + .collect::>(), + [" Hello", " world", " today"] + ); + } + + #[test] + fn final_emits_prefix_and_holds_last() { + let mut acc = TranscriptAccumulator::new(); + + acc.process(&partial( + &[(" Hello", 0.1, 0.5), (" world", 0.55, 0.9)], + " Hello world", + )); + + let update = acc + .process(&finalize( + &[(" Hello", 0.1, 0.5), (" world", 0.55, 0.9)], + " Hello world", + )) + .unwrap(); + + assert_eq!(update.new_final_words.len(), 1); + assert_eq!(update.new_final_words[0].text, " Hello"); + assert!(update.partial_words.is_empty()); + + let flushed = acc.flush(); + assert_eq!(flushed.new_final_words.len(), 1); + assert_eq!(flushed.new_final_words[0].text, " world"); + } + + #[test] + fn final_deduplicates_repeated_response() { + let mut acc = TranscriptAccumulator::new(); + + let r = finalize( + &[(" Hello", 0.1, 0.5), (" world", 0.6, 0.9)], + " Hello world", + ); + + let first = acc.process(&r).unwrap(); + let second = acc.process(&r).unwrap(); + + assert!(!first.new_final_words.is_empty()); + assert!(second.new_final_words.is_empty()); + } + + #[test] + fn final_clears_overlapping_partials() { + let mut acc = TranscriptAccumulator::new(); + + acc.process(&partial( + &[ + (" Hello", 0.1, 0.5), + (" world", 0.6, 1.0), + (" test", 1.1, 1.5), + ], + " Hello world test", + )); + + let update = acc + .process(&finalize( + &[(" Hello", 0.1, 0.5), (" world", 0.6, 1.0)], + " Hello world", + )) + .unwrap(); + + assert_eq!(update.partial_words.len(), 1); + assert_eq!(update.partial_words[0].text, " test"); + } + + #[test] + fn all_final_words_have_ids() { + let mut acc = TranscriptAccumulator::new(); + + let update = acc + .process(&finalize( + &[(" Hello", 0.1, 0.5), (" world", 0.6, 0.9)], + " Hello world", + )) + .unwrap(); + + assert!(update.new_final_words.iter().all(|w| !w.id.is_empty())); + + let flushed = acc.flush(); + assert!(flushed.new_final_words.iter().all(|w| !w.id.is_empty())); + } + + #[test] + fn flush_drains_held_word() { + let mut acc = TranscriptAccumulator::new(); + + acc.process(&finalize( + &[(" Hello", 0.1, 0.5), (" world", 0.6, 0.9)], + " Hello world", + )); + + let flushed = acc.flush(); + + assert_eq!(flushed.new_final_words.len(), 1); + assert_eq!(flushed.new_final_words[0].text, " world"); + assert!(!flushed.new_final_words[0].id.is_empty()); + } + + #[test] + fn flush_drains_partials_beyond_final_range() { + let mut acc = TranscriptAccumulator::new(); + + acc.process(&partial(&[(" later", 5.0, 5.5)], " later")); + + acc.process(&finalize( + &[(" Hello", 0.1, 0.5), (" world", 0.6, 0.9)], + " Hello world", + )); + + let flushed = acc.flush(); + + let texts: Vec<_> = flushed.new_final_words.iter().map(|w| &w.text).collect(); + assert!( + texts.contains(&&" world".to_string()) || texts.contains(&&" later".to_string()), + "flush must drain held + partials: {texts:?}" + ); + assert!(flushed.new_final_words.iter().all(|w| !w.id.is_empty())); + } + + #[test] + fn flush_on_empty_accumulator_is_empty() { + let mut acc = TranscriptAccumulator::new(); + let flushed = acc.flush(); + assert!(flushed.new_final_words.is_empty()); + assert!(flushed.partial_words.is_empty()); + assert!(flushed.speaker_hints.is_empty()); + } + + #[test] + fn non_transcript_responses_produce_no_update() { + let mut acc = TranscriptAccumulator::new(); + let ignored = StreamResponse::TerminalResponse { + request_id: "r".into(), + created: "now".into(), + duration: 1.0, + channels: 1, + }; + assert!(acc.process(&ignored).is_none()); + } + + #[test] + fn speaker_hints_extracted_from_final_words() { + let mut acc = TranscriptAccumulator::new(); + + let update = acc + .process(&finalize_with_speakers( + &[(" Hello", 0.1, 0.5, Some(0)), (" world", 0.6, 0.9, Some(1))], + " Hello world", + )) + .unwrap(); + + assert_eq!(update.new_final_words.len(), 1); + assert_eq!(update.speaker_hints.len(), 1); + assert_eq!(update.speaker_hints[0].speaker_index, 0); + assert_eq!( + update.speaker_hints[0].word_id, + update.new_final_words[0].id + ); + + let flushed = acc.flush(); + assert_eq!(flushed.new_final_words.len(), 1); + assert_eq!(flushed.speaker_hints.len(), 1); + assert_eq!(flushed.speaker_hints[0].speaker_index, 1); + } + + #[test] + fn no_speaker_hints_when_speaker_is_none() { + let mut acc = TranscriptAccumulator::new(); + + let update = acc + .process(&finalize( + &[(" Hello", 0.1, 0.5), (" world", 0.6, 0.9)], + " Hello world", + )) + .unwrap(); + + assert!(update.speaker_hints.is_empty()); + } + + macro_rules! fixture_test { + ($test_name:ident, $json:expr) => { + #[test] + fn $test_name() { + let responses: Vec = + serde_json::from_str($json).expect("fixture must parse as StreamResponse[]"); + assert_valid_output(&replay(&responses)); + } + }; + } + + fixture_test!( + deepgram_fixture_produces_valid_output, + hypr_data::english_1::DEEPGRAM_JSON + ); + fixture_test!( + soniox_fixture_produces_valid_output, + hypr_data::english_1::SONIOX_JSON + ); + fixture_test!( + soniox_korean_fixture_produces_valid_output, + hypr_data::korean_1::SONIOX_JSON + ); +} diff --git a/crates/transcript/src/accumulator/words.rs b/crates/transcript/src/accumulator/words.rs new file mode 100644 index 0000000000..49d9238c8f --- /dev/null +++ b/crates/transcript/src/accumulator/words.rs @@ -0,0 +1,511 @@ +use owhisper_interface::stream::Word; +use uuid::Uuid; + +// ── Public output types ───────────────────────────────────────────────────── + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, specta::Type)] +pub struct TranscriptWord { + pub id: String, + pub text: String, + pub start_ms: i64, + pub end_ms: i64, + pub channel: i32, +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, specta::Type)] +pub struct PartialWord { + pub text: String, + pub start_ms: i64, + pub end_ms: i64, + pub channel: i32, +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, specta::Type)] +pub struct SpeakerHint { + pub word_id: String, + pub speaker_index: i32, +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, specta::Type)] +pub struct TranscriptUpdate { + pub new_final_words: Vec, + pub speaker_hints: Vec, + pub partial_words: Vec, +} + +// ── Internal pipeline type ────────────────────────────────────────────────── + +#[derive(Debug, Clone)] +pub(super) struct RawWord { + pub(super) text: String, + pub(super) start_ms: i64, + pub(super) end_ms: i64, + pub(super) channel: i32, + pub(super) speaker: Option, +} + +impl RawWord { + pub(super) fn to_final(self, id: String) -> (TranscriptWord, Option) { + let hint = self.speaker.map(|speaker_index| SpeakerHint { + word_id: id.clone(), + speaker_index, + }); + let word = TranscriptWord { + id, + text: self.text, + start_ms: self.start_ms, + end_ms: self.end_ms, + channel: self.channel, + }; + (word, hint) + } + + pub(super) fn to_partial(&self) -> PartialWord { + PartialWord { + text: self.text.clone(), + start_ms: self.start_ms, + end_ms: self.end_ms, + channel: self.channel, + } + } +} + +// ── Assembly ───────────────────────────────────────────────────────────────── + +/// Assemble raw ASR tokens into merged `RawWord`s. +/// +/// The transcript string is the **sole oracle** for word boundaries within a +/// single response. `spacing_from_transcript` aligns each token to the +/// transcript; a space prefix means "new word", no space means "same word." +/// Adjacent tokens without a space prefix are unconditionally merged — +/// no timing heuristics. +pub(super) fn assemble(raw: &[Word], transcript: &str, channel: i32) -> Vec { + let spaced = spacing_from_transcript(raw, transcript); + let mut result: Vec = Vec::new(); + + for (w, text) in raw.iter().zip(&spaced) { + let start_ms = (w.start * 1000.0).round() as i64; + let end_ms = (w.end * 1000.0).round() as i64; + + let should_merge = !text.starts_with(' ') && result.last().is_some(); + + if should_merge { + let last = result.last_mut().unwrap(); + last.text.push_str(text); + last.end_ms = end_ms; + if last.speaker.is_none() { + last.speaker = w.speaker; + } + } else { + result.push(RawWord { + text: text.clone(), + start_ms, + end_ms, + channel, + speaker: w.speaker, + }); + } + } + + result +} + +/// Align each token to the transcript string and recover its spacing. +/// +/// The transcript is the oracle: if a token is found in the transcript, the +/// whitespace between the previous match and this one is prepended verbatim. +/// If a token cannot be found (ASR/transcript mismatch), a space is forced +/// so it becomes a separate word — "unknown = word boundary." +fn spacing_from_transcript(raw: &[Word], transcript: &str) -> Vec { + let mut result = Vec::with_capacity(raw.len()); + let mut pos = 0; + + for w in raw { + let text = w.punctuated_word.as_deref().unwrap_or(&w.word); + let trimmed = text.trim(); + + if trimmed.is_empty() { + result.push(text.to_string()); + continue; + } + + match transcript[pos..].find(trimmed) { + Some(found) => { + let abs = pos + found; + result.push(format!("{}{trimmed}", &transcript[pos..abs])); + pos = abs + trimmed.len(); + } + None => { + let mut fallback = text.to_string(); + if !fallback.starts_with(' ') { + fallback.insert(0, ' '); + } + result.push(fallback); + } + } + } + + result +} + +// ── Pipeline stages ────────────────────────────────────────────────────────── + +/// Drop words already covered by the watermark (deduplication). +pub(super) fn dedup(words: Vec, watermark: i64) -> Vec { + words + .into_iter() + .skip_while(|w| w.end_ms <= watermark) + .collect() +} + +/// Cross-response word boundary handling — the one place where a timing +/// heuristic is unavoidable, because no transcript spans both responses. +/// +/// Holds back the last word of each finalized batch so it can be merged +/// with the first word of the next batch if the provider split a word +/// across responses (common with Korean particles, contractions, etc.). +pub(super) fn stitch( + held: Option, + mut words: Vec, +) -> (Vec, Option) { + if words.is_empty() { + return (held.into_iter().collect(), None); + } + + if let Some(h) = held { + if should_stitch(&h, &words[0]) { + words[0] = merge_words(h, words[0].clone()); + } else { + words.insert(0, h); + } + } + + let new_held = words.pop(); + (words, new_held) +} + +/// Replace the time range covered by `incoming` within `existing`. +pub(super) fn splice(existing: &[RawWord], incoming: Vec) -> Vec { + let first_start = incoming.first().map_or(0, |w| w.start_ms); + let last_end = incoming.last().map_or(0, |w| w.end_ms); + + existing + .iter() + .filter(|w| w.end_ms <= first_start) + .cloned() + .chain(incoming) + .chain(existing.iter().filter(|w| w.start_ms >= last_end).cloned()) + .collect() +} + +/// Remove partials that overlap with the finalized time range. +pub(super) fn strip_overlap(partials: Vec, final_end: i64) -> Vec { + partials + .into_iter() + .filter(|w| w.start_ms > final_end) + .collect() +} + +// ── Word-level transforms ──────────────────────────────────────────────────── + +pub(super) fn ensure_space_prefix_raw(w: &mut RawWord) { + if !w.text.starts_with(' ') { + w.text.insert(0, ' '); + } +} + +pub(super) fn ensure_space_prefix_partial(w: &mut PartialWord) { + if !w.text.starts_with(' ') { + w.text.insert(0, ' '); + } +} + +fn should_stitch(tail: &RawWord, head: &RawWord) -> bool { + !head.text.starts_with(' ') && (head.start_ms - tail.end_ms) <= 300 +} + +fn merge_words(mut left: RawWord, right: RawWord) -> RawWord { + left.text.push_str(&right.text); + left.end_ms = right.end_ms; + if left.speaker.is_none() { + left.speaker = right.speaker; + } + left +} + +/// Convert a list of RawWords into finalized TranscriptWords + SpeakerHints. +/// Assigns UUIDs, ensures space prefixes, and extracts speaker data. +pub(super) fn finalize_words(mut words: Vec) -> (Vec, Vec) { + words.iter_mut().for_each(ensure_space_prefix_raw); + + let mut final_words = Vec::with_capacity(words.len()); + let mut hints = Vec::new(); + + for w in words { + let id = Uuid::new_v4().to_string(); + let (word, hint) = w.to_final(id); + final_words.push(word); + if let Some(h) = hint { + hints.push(h); + } + } + + (final_words, hints) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn raw_word(text: &str, start: f64, end: f64) -> Word { + Word { + word: text.to_string(), + start, + end, + confidence: 1.0, + speaker: None, + punctuated_word: Some(text.to_string()), + language: None, + } + } + + fn word(text: &str, start_ms: i64, end_ms: i64) -> RawWord { + RawWord { + text: text.to_string(), + start_ms, + end_ms, + channel: 0, + speaker: None, + } + } + + // ── spacing_from_transcript ────────────────────────────────────────── + + #[test] + fn spacing_recovered_from_transcript() { + let raw = vec![raw_word("Hello", 0.0, 0.5), raw_word("world", 0.6, 1.0)]; + let spaced = spacing_from_transcript(&raw, " Hello world"); + assert_eq!(spaced, [" Hello", " world"]); + } + + #[test] + fn spacing_forces_word_boundary_on_unfound_token() { + let raw = vec![raw_word("Hello", 0.0, 0.5)]; + let spaced = spacing_from_transcript(&raw, "completely different"); + assert_eq!(spaced, [" Hello"]); + } + + #[test] + fn spacing_preserves_no_space_at_transcript_start() { + let raw = vec![raw_word("기", 0.0, 0.1), raw_word("간", 0.2, 0.3)]; + let spaced = spacing_from_transcript(&raw, "기간"); + assert_eq!(spaced, ["기", "간"]); + } + + // ── assemble ───────────────────────────────────────────────────────── + + #[test] + fn assemble_merges_attached_punctuation() { + let raw = vec![raw_word(" Hello", 0.0, 0.5), raw_word("'s", 0.51, 0.6)]; + let words = assemble(&raw, " Hello's", 0); + assert_eq!(words.len(), 1); + assert_eq!(words[0].text, " Hello's"); + assert_eq!(words[0].end_ms, 600); + } + + #[test] + fn assemble_does_not_merge_spaced_tokens() { + let raw = vec![raw_word(" Hello", 0.0, 0.5), raw_word(" world", 0.51, 1.0)]; + let words = assemble(&raw, " Hello world", 0); + assert_eq!(words.len(), 2); + } + + #[test] + fn assemble_separates_unfound_tokens() { + let raw = vec![raw_word("Hello", 0.0, 0.5), raw_word("world", 0.51, 0.6)]; + let words = assemble(&raw, "completely different text", 0); + assert_eq!(words.len(), 2); + assert!(words[0].text.starts_with(' ')); + assert!(words[1].text.starts_with(' ')); + } + + #[test] + fn assemble_merges_cjk_syllables_with_large_gap() { + let raw = vec![ + raw_word("있는", 0.0, 0.3), + raw_word("데", 0.54, 0.66), + raw_word(",", 0.84, 0.9), + ]; + let words = assemble(&raw, " 있는데,", 0); + assert_eq!( + words.len(), + 1, + "syllables in same CJK word must merge: {words:?}" + ); + assert_eq!(words[0].text, " 있는데,"); + assert_eq!(words[0].end_ms, 900); + } + + #[test] + fn assemble_splits_cjk_words_at_transcript_space_boundary() { + let raw = vec![ + raw_word("있는", 0.0, 0.3), + raw_word("데", 0.54, 0.66), + raw_word("학습", 1.0, 1.3), + raw_word("과", 1.54, 1.66), + ]; + let words = assemble(&raw, " 있는데 학습과", 0); + assert_eq!( + words.len(), + 2, + "space in transcript must split words: {words:?}" + ); + assert_eq!(words[0].text, " 있는데"); + assert_eq!(words[1].text, " 학습과"); + } + + // ── dedup ──────────────────────────────────────────────────────────── + + #[test] + fn dedup_drops_words_at_or_before_watermark() { + let words = vec![ + word(" a", 0, 100), + word(" b", 100, 200), + word(" c", 200, 300), + ]; + let result = dedup(words, 200); + assert_eq!(result.len(), 1); + assert_eq!(result[0].text, " c"); + } + + #[test] + fn dedup_keeps_all_when_watermark_is_zero() { + let words = vec![word(" a", 0, 100), word(" b", 100, 200)]; + let result = dedup(words, 0); + assert_eq!(result.len(), 2); + } + + #[test] + fn dedup_returns_empty_when_all_covered() { + let words = vec![word(" a", 0, 100), word(" b", 100, 200)]; + let result = dedup(words, 200); + assert!(result.is_empty()); + } + + // ── stitch ─────────────────────────────────────────────────────────── + + #[test] + fn stitch_no_held_holds_last() { + let ws = vec![word(" Hello", 0, 500), word(" world", 600, 1000)]; + let (emitted, held) = stitch(None, ws); + assert_eq!(emitted.len(), 1); + assert_eq!(emitted[0].text, " Hello"); + assert_eq!(held.unwrap().text, " world"); + } + + #[test] + fn stitch_merges_spaceless_adjacent_head() { + let held = word(" Hello", 0, 500); + let ws = vec![word("'s", 550, 700)]; + let (emitted, held) = stitch(Some(held), ws); + assert!(emitted.is_empty()); + assert_eq!(held.unwrap().text, " Hello's"); + } + + #[test] + fn stitch_separates_spaced_head() { + let held = word(" Hello", 0, 500); + let ws = vec![word(" world", 600, 1000)]; + let (emitted, held) = stitch(Some(held), ws); + assert_eq!(emitted.len(), 1); + assert_eq!(emitted[0].text, " Hello"); + assert_eq!(held.unwrap().text, " world"); + } + + #[test] + fn stitch_separates_distant_spaceless_head() { + let held = word(" Hello", 0, 500); + let ws = vec![word("world", 1000, 1500)]; + let (emitted, held) = stitch(Some(held), ws); + assert_eq!(emitted.len(), 1); + assert_eq!(emitted[0].text, " Hello"); + assert_eq!(held.unwrap().text, "world"); + } + + #[test] + fn stitch_empty_batch_releases_held() { + let held = word(" Hello", 0, 500); + let (emitted, held) = stitch(Some(held), vec![]); + assert_eq!(emitted.len(), 1); + assert!(held.is_none()); + } + + #[test] + fn stitch_single_word_batch_yields_no_emission() { + let ws = vec![word(" Hello", 0, 500)]; + let (emitted, held) = stitch(None, ws); + assert!(emitted.is_empty()); + assert_eq!(held.unwrap().text, " Hello"); + } + + // ── splice ─────────────────────────────────────────────────────────── + + #[test] + fn splice_replaces_overlapping_range() { + let existing = vec![ + word(" a", 0, 100), + word(" b", 100, 200), + word(" c", 300, 400), + ]; + let incoming = vec![word(" B", 100, 200), word(" new", 200, 300)]; + let result = splice(&existing, incoming); + assert_eq!( + result.iter().map(|w| &w.text[..]).collect::>(), + [" a", " B", " new", " c"] + ); + } + + #[test] + fn splice_appends_when_no_overlap() { + let existing = vec![word(" a", 0, 100)]; + let incoming = vec![word(" b", 200, 300)]; + let result = splice(&existing, incoming); + assert_eq!(result.len(), 2); + } + + #[test] + fn splice_full_replacement() { + let existing = vec![word(" a", 0, 100), word(" b", 100, 200)]; + let incoming = vec![ + word(" x", 0, 100), + word(" y", 100, 200), + word(" z", 200, 300), + ]; + let result = splice(&existing, incoming); + assert_eq!( + result.iter().map(|w| &w.text[..]).collect::>(), + [" x", " y", " z"] + ); + } + + // ── strip_overlap ──────────────────────────────────────────────────── + + #[test] + fn strip_overlap_removes_covered_partials() { + let partials = vec![ + word(" a", 0, 100), + word(" b", 100, 200), + word(" c", 300, 400), + ]; + let result = strip_overlap(partials, 200); + assert_eq!(result.len(), 1); + assert_eq!(result[0].text, " c"); + } + + #[test] + fn strip_overlap_keeps_all_beyond_range() { + let partials = vec![word(" a", 300, 400), word(" b", 400, 500)]; + let result = strip_overlap(partials, 200); + assert_eq!(result.len(), 2); + } +} diff --git a/crates/transcript/src/lib.rs b/crates/transcript/src/lib.rs new file mode 100644 index 0000000000..2d5b405aa4 --- /dev/null +++ b/crates/transcript/src/lib.rs @@ -0,0 +1 @@ +pub mod accumulator; diff --git a/plugins/listener/Cargo.toml b/plugins/listener/Cargo.toml index 171bc927ab..57ab253489 100644 --- a/plugins/listener/Cargo.toml +++ b/plugins/listener/Cargo.toml @@ -32,6 +32,8 @@ hypr-mac = { workspace = true } hypr-vad-ext = { workspace = true } tauri-plugin-fs-sync = { workspace = true } +hypr-transcript = { workspace = true } + owhisper-client = { workspace = true } owhisper-interface = { workspace = true } diff --git a/plugins/listener/js/bindings.gen.ts b/plugins/listener/js/bindings.gen.ts index 5dc56ebaaa..3f23d01228 100644 --- a/plugins/listener/js/bindings.gen.ts +++ b/plugins/listener/js/bindings.gen.ts @@ -1,185 +1,326 @@ // @ts-nocheck +/** tauri-specta globals **/ +import { + Channel as TAURI_CHANNEL, + invoke as TAURI_INVOKE, +} from "@tauri-apps/api/core"; +import * as TAURI_API_EVENT from "@tauri-apps/api/event"; +import { type WebviewWindow as __WebviewWindow__ } from "@tauri-apps/api/webviewWindow"; // This file was generated by [tauri-specta](https://github.com/oscartbeaumont/tauri-specta). Do not edit this file manually. /** user-defined commands **/ - export const commands = { -async listMicrophoneDevices() : Promise> { + async listMicrophoneDevices(): Promise> { try { - return { status: "ok", data: await TAURI_INVOKE("plugin:listener|list_microphone_devices") }; -} catch (e) { - if(e instanceof Error) throw e; - else return { status: "error", error: e as any }; -} -}, -async getCurrentMicrophoneDevice() : Promise> { + return { + status: "ok", + data: await TAURI_INVOKE("plugin:listener|list_microphone_devices"), + }; + } catch (e) { + if (e instanceof Error) throw e; + else return { status: "error", error: e as any }; + } + }, + async getCurrentMicrophoneDevice(): Promise> { try { - return { status: "ok", data: await TAURI_INVOKE("plugin:listener|get_current_microphone_device") }; -} catch (e) { - if(e instanceof Error) throw e; - else return { status: "error", error: e as any }; -} -}, -async getMicMuted() : Promise> { + return { + status: "ok", + data: await TAURI_INVOKE( + "plugin:listener|get_current_microphone_device", + ), + }; + } catch (e) { + if (e instanceof Error) throw e; + else return { status: "error", error: e as any }; + } + }, + async getMicMuted(): Promise> { try { - return { status: "ok", data: await TAURI_INVOKE("plugin:listener|get_mic_muted") }; -} catch (e) { - if(e instanceof Error) throw e; - else return { status: "error", error: e as any }; -} -}, -async setMicMuted(muted: boolean) : Promise> { + return { + status: "ok", + data: await TAURI_INVOKE("plugin:listener|get_mic_muted"), + }; + } catch (e) { + if (e instanceof Error) throw e; + else return { status: "error", error: e as any }; + } + }, + async setMicMuted(muted: boolean): Promise> { try { - return { status: "ok", data: await TAURI_INVOKE("plugin:listener|set_mic_muted", { muted }) }; -} catch (e) { - if(e instanceof Error) throw e; - else return { status: "error", error: e as any }; -} -}, -async startSession(params: SessionParams) : Promise> { + return { + status: "ok", + data: await TAURI_INVOKE("plugin:listener|set_mic_muted", { muted }), + }; + } catch (e) { + if (e instanceof Error) throw e; + else return { status: "error", error: e as any }; + } + }, + async startSession(params: SessionParams): Promise> { try { - return { status: "ok", data: await TAURI_INVOKE("plugin:listener|start_session", { params }) }; -} catch (e) { - if(e instanceof Error) throw e; - else return { status: "error", error: e as any }; -} -}, -async stopSession() : Promise> { + return { + status: "ok", + data: await TAURI_INVOKE("plugin:listener|start_session", { params }), + }; + } catch (e) { + if (e instanceof Error) throw e; + else return { status: "error", error: e as any }; + } + }, + async stopSession(): Promise> { try { - return { status: "ok", data: await TAURI_INVOKE("plugin:listener|stop_session") }; -} catch (e) { - if(e instanceof Error) throw e; - else return { status: "error", error: e as any }; -} -}, -async getState() : Promise> { + return { + status: "ok", + data: await TAURI_INVOKE("plugin:listener|stop_session"), + }; + } catch (e) { + if (e instanceof Error) throw e; + else return { status: "error", error: e as any }; + } + }, + async getState(): Promise> { try { - return { status: "ok", data: await TAURI_INVOKE("plugin:listener|get_state") }; -} catch (e) { - if(e instanceof Error) throw e; - else return { status: "error", error: e as any }; -} -}, -async isSupportedLanguagesLive(provider: string, model: string | null, languages: string[]) : Promise> { + return { + status: "ok", + data: await TAURI_INVOKE("plugin:listener|get_state"), + }; + } catch (e) { + if (e instanceof Error) throw e; + else return { status: "error", error: e as any }; + } + }, + async isSupportedLanguagesLive( + provider: string, + model: string | null, + languages: string[], + ): Promise> { try { - return { status: "ok", data: await TAURI_INVOKE("plugin:listener|is_supported_languages_live", { provider, model, languages }) }; -} catch (e) { - if(e instanceof Error) throw e; - else return { status: "error", error: e as any }; -} -}, -async suggestProvidersForLanguagesLive(languages: string[]) : Promise> { + return { + status: "ok", + data: await TAURI_INVOKE( + "plugin:listener|is_supported_languages_live", + { provider, model, languages }, + ), + }; + } catch (e) { + if (e instanceof Error) throw e; + else return { status: "error", error: e as any }; + } + }, + async suggestProvidersForLanguagesLive( + languages: string[], + ): Promise> { try { - return { status: "ok", data: await TAURI_INVOKE("plugin:listener|suggest_providers_for_languages_live", { languages }) }; -} catch (e) { - if(e instanceof Error) throw e; - else return { status: "error", error: e as any }; -} -}, -async listDocumentedLanguageCodesLive() : Promise> { + return { + status: "ok", + data: await TAURI_INVOKE( + "plugin:listener|suggest_providers_for_languages_live", + { languages }, + ), + }; + } catch (e) { + if (e instanceof Error) throw e; + else return { status: "error", error: e as any }; + } + }, + async listDocumentedLanguageCodesLive(): Promise> { try { - return { status: "ok", data: await TAURI_INVOKE("plugin:listener|list_documented_language_codes_live") }; -} catch (e) { - if(e instanceof Error) throw e; - else return { status: "error", error: e as any }; -} -} -} + return { + status: "ok", + data: await TAURI_INVOKE( + "plugin:listener|list_documented_language_codes_live", + ), + }; + } catch (e) { + if (e instanceof Error) throw e; + else return { status: "error", error: e as any }; + } + }, +}; /** user-defined events **/ - export const events = __makeEvents__<{ -sessionDataEvent: SessionDataEvent, -sessionErrorEvent: SessionErrorEvent, -sessionLifecycleEvent: SessionLifecycleEvent, -sessionProgressEvent: SessionProgressEvent + sessionDataEvent: SessionDataEvent; + sessionErrorEvent: SessionErrorEvent; + sessionLifecycleEvent: SessionLifecycleEvent; + sessionProgressEvent: SessionProgressEvent; }>({ -sessionDataEvent: "plugin:listener:session-data-event", -sessionErrorEvent: "plugin:listener:session-error-event", -sessionLifecycleEvent: "plugin:listener:session-lifecycle-event", -sessionProgressEvent: "plugin:listener:session-progress-event" -}) + sessionDataEvent: "plugin:listener:session-data-event", + sessionErrorEvent: "plugin:listener:session-error-event", + sessionLifecycleEvent: "plugin:listener:session-lifecycle-event", + sessionProgressEvent: "plugin:listener:session-progress-event", +}); /** user-defined constants **/ - - /** user-defined types **/ -export type DegradedError = { type: "authentication_failed"; provider: string } | { type: "upstream_unavailable"; message: string } | { type: "connection_timeout" } | { type: "stream_error"; message: string } -export type SessionDataEvent = { type: "audio_amplitude"; session_id: string; mic: number; speaker: number } | { type: "mic_muted"; session_id: string; value: boolean } | { type: "stream_response"; session_id: string; response: StreamResponse } -export type SessionErrorEvent = { type: "audio_error"; session_id: string; error: string; device: string | null; is_fatal: boolean } | { type: "connection_error"; session_id: string; error: string } -export type SessionLifecycleEvent = { type: "inactive"; session_id: string; error: string | null } | { type: "active"; session_id: string; error?: DegradedError | null } | { type: "finalizing"; session_id: string } -export type SessionParams = { session_id: string; languages: string[]; onboarding: boolean; record_enabled: boolean; model: string; base_url: string; api_key: string; keywords: string[] } -export type SessionProgressEvent = { type: "audio_initializing"; session_id: string } | { type: "audio_ready"; session_id: string; device: string | null } | { type: "connecting"; session_id: string } | { type: "connected"; session_id: string; adapter: string } -export type State = "active" | "inactive" | "finalizing" -export type StreamAlternatives = { transcript: string; words: StreamWord[]; confidence: number; languages?: string[] } -export type StreamChannel = { alternatives: StreamAlternatives[] } -export type StreamExtra = { started_unix_millis: number } -export type StreamMetadata = { request_id: string; model_info: StreamModelInfo; model_uuid: string; extra?: StreamExtra } -export type StreamModelInfo = { name: string; version: string; arch: string } -export type StreamResponse = { type: "Results"; start: number; duration: number; is_final: boolean; speech_final: boolean; from_finalize: boolean; channel: StreamChannel; metadata: StreamMetadata; channel_index: number[] } | { type: "Metadata"; request_id: string; created: string; duration: number; channels: number } | { type: "SpeechStarted"; channel: number[]; timestamp: number } | { type: "UtteranceEnd"; channel: number[]; last_word_end: number } | { type: "Error"; error_code: number | null; error_message: string; provider: string } -export type StreamWord = { word: string; start: number; end: number; confidence: number; speaker: number | null; punctuated_word: string | null; language: string | null } - -/** tauri-specta globals **/ - -import { - invoke as TAURI_INVOKE, - Channel as TAURI_CHANNEL, -} from "@tauri-apps/api/core"; -import * as TAURI_API_EVENT from "@tauri-apps/api/event"; -import { type WebviewWindow as __WebviewWindow__ } from "@tauri-apps/api/webviewWindow"; +export type DegradedError = + | { type: "authentication_failed"; provider: string } + | { type: "upstream_unavailable"; message: string } + | { type: "connection_timeout" } + | { type: "stream_error"; message: string }; +export type PartialWord = { + text: string; + start_ms: number; + end_ms: number; + channel: number; +}; +export type SessionDataEvent = + | { + type: "audio_amplitude"; + session_id: string; + mic: number; + speaker: number; + } + | { type: "mic_muted"; session_id: string; value: boolean } + | { type: "stream_response"; session_id: string; response: StreamResponse } + | { + type: "transcript_update"; + session_id: string; + new_final_words: TranscriptWord[]; + speaker_hints: SpeakerHint[]; + partial_words: PartialWord[]; + }; +export type SessionErrorEvent = + | { + type: "audio_error"; + session_id: string; + error: string; + device: string | null; + is_fatal: boolean; + } + | { type: "connection_error"; session_id: string; error: string }; +export type SessionLifecycleEvent = + | { type: "inactive"; session_id: string; error: string | null } + | { type: "active"; session_id: string; error?: DegradedError | null } + | { type: "finalizing"; session_id: string }; +export type SessionParams = { + session_id: string; + languages: string[]; + onboarding: boolean; + record_enabled: boolean; + model: string; + base_url: string; + api_key: string; + keywords: string[]; +}; +export type SessionProgressEvent = + | { type: "audio_initializing"; session_id: string } + | { type: "audio_ready"; session_id: string; device: string | null } + | { type: "connecting"; session_id: string } + | { type: "connected"; session_id: string; adapter: string }; +export type SpeakerHint = { word_id: string; speaker_index: number }; +export type State = "active" | "inactive" | "finalizing"; +export type StreamAlternatives = { + transcript: string; + words: StreamWord[]; + confidence: number; + languages?: string[]; +}; +export type StreamChannel = { alternatives: StreamAlternatives[] }; +export type StreamExtra = { started_unix_millis: number }; +export type StreamMetadata = { + request_id: string; + model_info: StreamModelInfo; + model_uuid: string; + extra?: StreamExtra; +}; +export type StreamModelInfo = { name: string; version: string; arch: string }; +export type StreamResponse = + | { + type: "Results"; + start: number; + duration: number; + is_final: boolean; + speech_final: boolean; + from_finalize: boolean; + channel: StreamChannel; + metadata: StreamMetadata; + channel_index: number[]; + } + | { + type: "Metadata"; + request_id: string; + created: string; + duration: number; + channels: number; + } + | { type: "SpeechStarted"; channel: number[]; timestamp: number } + | { type: "UtteranceEnd"; channel: number[]; last_word_end: number } + | { + type: "Error"; + error_code: number | null; + error_message: string; + provider: string; + }; +export type StreamWord = { + word: string; + start: number; + end: number; + confidence: number; + speaker: number | null; + punctuated_word: string | null; + language: string | null; +}; +export type TranscriptWord = { + id: string; + text: string; + start_ms: number; + end_ms: number; + channel: number; +}; type __EventObj__ = { - listen: ( - cb: TAURI_API_EVENT.EventCallback, - ) => ReturnType>; - once: ( - cb: TAURI_API_EVENT.EventCallback, - ) => ReturnType>; - emit: null extends T - ? (payload?: T) => ReturnType - : (payload: T) => ReturnType; + listen: ( + cb: TAURI_API_EVENT.EventCallback, + ) => ReturnType>; + once: ( + cb: TAURI_API_EVENT.EventCallback, + ) => ReturnType>; + emit: null extends T + ? (payload?: T) => ReturnType + : (payload: T) => ReturnType; }; export type Result = - | { status: "ok"; data: T } - | { status: "error"; error: E }; + | { status: "ok"; data: T } + | { status: "error"; error: E }; function __makeEvents__>( - mappings: Record, + mappings: Record, ) { - return new Proxy( - {} as unknown as { - [K in keyof T]: __EventObj__ & { - (handle: __WebviewWindow__): __EventObj__; - }; - }, - { - get: (_, event) => { - const name = mappings[event as keyof T]; + return new Proxy( + {} as unknown as { + [K in keyof T]: __EventObj__ & { + (handle: __WebviewWindow__): __EventObj__; + }; + }, + { + get: (_, event) => { + const name = mappings[event as keyof T]; - return new Proxy((() => {}) as any, { - apply: (_, __, [window]: [__WebviewWindow__]) => ({ - listen: (arg: any) => window.listen(name, arg), - once: (arg: any) => window.once(name, arg), - emit: (arg: any) => window.emit(name, arg), - }), - get: (_, command: keyof __EventObj__) => { - switch (command) { - case "listen": - return (arg: any) => TAURI_API_EVENT.listen(name, arg); - case "once": - return (arg: any) => TAURI_API_EVENT.once(name, arg); - case "emit": - return (arg: any) => TAURI_API_EVENT.emit(name, arg); - } - }, - }); - }, - }, - ); + return new Proxy((() => {}) as any, { + apply: (_, __, [window]: [__WebviewWindow__]) => ({ + listen: (arg: any) => window.listen(name, arg), + once: (arg: any) => window.once(name, arg), + emit: (arg: any) => window.emit(name, arg), + }), + get: (_, command: keyof __EventObj__) => { + switch (command) { + case "listen": + return (arg: any) => TAURI_API_EVENT.listen(name, arg); + case "once": + return (arg: any) => TAURI_API_EVENT.once(name, arg); + case "emit": + return (arg: any) => TAURI_API_EVENT.emit(name, arg); + } + }, + }); + }, + }, + ); } diff --git a/plugins/listener/src/actors/listener/mod.rs b/plugins/listener/src/actors/listener/mod.rs index d1a3356688..549f617706 100644 --- a/plugins/listener/src/actors/listener/mod.rs +++ b/plugins/listener/src/actors/listener/mod.rs @@ -14,6 +14,7 @@ use owhisper_interface::{ControlMessage, MixedMessage}; use super::session::session_span; use crate::{DegradedError, SessionDataEvent, SessionErrorEvent, SessionProgressEvent}; +use hypr_transcript::accumulator::TranscriptAccumulator; use adapters::spawn_rx_task; @@ -50,6 +51,7 @@ pub struct ListenerState { tx: ChannelSender, rx_task: tokio::task::JoinHandle<()>, shutdown_tx: Option>, + accumulator: TranscriptAccumulator, } pub(super) enum ChannelSender { @@ -120,6 +122,7 @@ impl Actor for ListenerActor { tx, rx_task, shutdown_tx: Some(shutdown_tx), + accumulator: TranscriptAccumulator::new(), }; Ok(state) @@ -137,6 +140,18 @@ impl Actor for ListenerActor { let _ = shutdown_tx.send(()); let _ = (&mut state.rx_task).await; } + + let flush = state.accumulator.flush(); + if !flush.new_final_words.is_empty() { + let _ = (SessionDataEvent::TranscriptUpdate { + session_id: state.args.session_id.clone(), + new_final_words: flush.new_final_words, + speaker_hints: flush.speaker_hints, + partial_words: flush.partial_words, + }) + .emit(&state.args.app); + } + Ok(()) } @@ -209,6 +224,19 @@ impl Actor for ListenerActor { crate::actors::ChannelMode::MicAndSpeaker => {} } + if let Some(update) = state.accumulator.process(&response) { + if let Err(error) = (SessionDataEvent::TranscriptUpdate { + session_id: state.args.session_id.clone(), + new_final_words: update.new_final_words, + speaker_hints: update.speaker_hints, + partial_words: update.partial_words, + }) + .emit(&state.args.app) + { + tracing::error!(?error, "transcript_update_emit_failed"); + } + } + if let Err(error) = (SessionDataEvent::StreamResponse { session_id: state.args.session_id.clone(), response: Box::new(response), diff --git a/plugins/listener/src/events.rs b/plugins/listener/src/events.rs index a8449f75eb..0b443d8807 100644 --- a/plugins/listener/src/events.rs +++ b/plugins/listener/src/events.rs @@ -1,5 +1,7 @@ use owhisper_interface::stream::StreamResponse; +use hypr_transcript::accumulator::{PartialWord, SpeakerHint, TranscriptWord}; + #[macro_export] macro_rules! common_event_derives { ($item:item) => { @@ -80,5 +82,12 @@ common_event_derives! { session_id: String, response: Box, }, + #[serde(rename = "transcript_update")] + TranscriptUpdate { + session_id: String, + new_final_words: Vec, + speaker_hints: Vec, + partial_words: Vec, + }, } } diff --git a/plugins/listener2/Cargo.toml b/plugins/listener2/Cargo.toml index fc336f6c25..a4ab66fa67 100644 --- a/plugins/listener2/Cargo.toml +++ b/plugins/listener2/Cargo.toml @@ -19,6 +19,7 @@ tauri-plugin-settings = { workspace = true } hypr-audio-utils = { workspace = true } hypr-host = { workspace = true } hypr-language = { workspace = true } +hypr-transcript = { workspace = true } owhisper-client = { workspace = true, features = ["argmax"] } owhisper-interface = { workspace = true } diff --git a/plugins/listener2/js/bindings.gen.ts b/plugins/listener2/js/bindings.gen.ts index a7d3d873b1..7f0f9a2332 100644 --- a/plugins/listener2/js/bindings.gen.ts +++ b/plugins/listener2/js/bindings.gen.ts @@ -1,152 +1,231 @@ // @ts-nocheck +/** tauri-specta globals **/ +import { + Channel as TAURI_CHANNEL, + invoke as TAURI_INVOKE, +} from "@tauri-apps/api/core"; +import * as TAURI_API_EVENT from "@tauri-apps/api/event"; +import { type WebviewWindow as __WebviewWindow__ } from "@tauri-apps/api/webviewWindow"; // This file was generated by [tauri-specta](https://github.com/oscartbeaumont/tauri-specta). Do not edit this file manually. /** user-defined commands **/ - export const commands = { -async runBatch(params: BatchParams) : Promise> { + async runBatch(params: BatchParams): Promise> { try { - return { status: "ok", data: await TAURI_INVOKE("plugin:listener2|run_batch", { params }) }; -} catch (e) { - if(e instanceof Error) throw e; - else return { status: "error", error: e as any }; -} -}, -async parseSubtitle(path: string) : Promise> { + return { + status: "ok", + data: await TAURI_INVOKE("plugin:listener2|run_batch", { params }), + }; + } catch (e) { + if (e instanceof Error) throw e; + else return { status: "error", error: e as any }; + } + }, + async parseSubtitle(path: string): Promise> { try { - return { status: "ok", data: await TAURI_INVOKE("plugin:listener2|parse_subtitle", { path }) }; -} catch (e) { - if(e instanceof Error) throw e; - else return { status: "error", error: e as any }; -} -}, -async exportToVtt(sessionId: string, words: VttWord[]) : Promise> { + return { + status: "ok", + data: await TAURI_INVOKE("plugin:listener2|parse_subtitle", { path }), + }; + } catch (e) { + if (e instanceof Error) throw e; + else return { status: "error", error: e as any }; + } + }, + async exportToVtt( + sessionId: string, + words: VttWord[], + ): Promise> { try { - return { status: "ok", data: await TAURI_INVOKE("plugin:listener2|export_to_vtt", { sessionId, words }) }; -} catch (e) { - if(e instanceof Error) throw e; - else return { status: "error", error: e as any }; -} -}, -async isSupportedLanguagesBatch(provider: string, model: string | null, languages: string[]) : Promise> { + return { + status: "ok", + data: await TAURI_INVOKE("plugin:listener2|export_to_vtt", { + sessionId, + words, + }), + }; + } catch (e) { + if (e instanceof Error) throw e; + else return { status: "error", error: e as any }; + } + }, + async isSupportedLanguagesBatch( + provider: string, + model: string | null, + languages: string[], + ): Promise> { try { - return { status: "ok", data: await TAURI_INVOKE("plugin:listener2|is_supported_languages_batch", { provider, model, languages }) }; -} catch (e) { - if(e instanceof Error) throw e; - else return { status: "error", error: e as any }; -} -}, -async suggestProvidersForLanguagesBatch(languages: string[]) : Promise> { + return { + status: "ok", + data: await TAURI_INVOKE( + "plugin:listener2|is_supported_languages_batch", + { provider, model, languages }, + ), + }; + } catch (e) { + if (e instanceof Error) throw e; + else return { status: "error", error: e as any }; + } + }, + async suggestProvidersForLanguagesBatch( + languages: string[], + ): Promise> { try { - return { status: "ok", data: await TAURI_INVOKE("plugin:listener2|suggest_providers_for_languages_batch", { languages }) }; -} catch (e) { - if(e instanceof Error) throw e; - else return { status: "error", error: e as any }; -} -}, -async listDocumentedLanguageCodesBatch() : Promise> { + return { + status: "ok", + data: await TAURI_INVOKE( + "plugin:listener2|suggest_providers_for_languages_batch", + { languages }, + ), + }; + } catch (e) { + if (e instanceof Error) throw e; + else return { status: "error", error: e as any }; + } + }, + async listDocumentedLanguageCodesBatch(): Promise> { try { - return { status: "ok", data: await TAURI_INVOKE("plugin:listener2|list_documented_language_codes_batch") }; -} catch (e) { - if(e instanceof Error) throw e; - else return { status: "error", error: e as any }; -} -} -} + return { + status: "ok", + data: await TAURI_INVOKE( + "plugin:listener2|list_documented_language_codes_batch", + ), + }; + } catch (e) { + if (e instanceof Error) throw e; + else return { status: "error", error: e as any }; + } + }, +}; /** user-defined events **/ - export const events = __makeEvents__<{ -batchEvent: BatchEvent + batchEvent: BatchEvent; }>({ -batchEvent: "plugin:listener2:batch-event" -}) + batchEvent: "plugin:listener2:batch-event", +}); /** user-defined constants **/ - - /** user-defined types **/ -export type BatchAlternatives = { transcript: string; confidence: number; words?: BatchWord[] } -export type BatchChannel = { alternatives: BatchAlternatives[] } -export type BatchEvent = { type: "batchStarted"; session_id: string } | { type: "batchResponse"; session_id: string; response: BatchResponse } | { type: "batchProgress"; session_id: string; response: StreamResponse; percentage: number } | { type: "batchFailed"; session_id: string; error: string } -export type BatchParams = { session_id: string; provider: BatchProvider; file_path: string; model?: string | null; base_url: string; api_key: string; languages?: string[]; keywords?: string[] } -export type BatchProvider = "deepgram" | "soniox" | "assemblyai" | "am" -export type BatchResponse = { metadata: JsonValue; results: BatchResults } -export type BatchResults = { channels: BatchChannel[] } -export type BatchWord = { word: string; start: number; end: number; confidence: number; speaker: number | null; punctuated_word: string | null } -export type JsonValue = null | boolean | number | string | JsonValue[] | Partial<{ [key in string]: JsonValue }> -export type StreamAlternatives = { transcript: string; words: StreamWord[]; confidence: number; languages?: string[] } -export type StreamChannel = { alternatives: StreamAlternatives[] } -export type StreamExtra = { started_unix_millis: number } -export type StreamMetadata = { request_id: string; model_info: StreamModelInfo; model_uuid: string; extra?: StreamExtra } -export type StreamModelInfo = { name: string; version: string; arch: string } -export type StreamResponse = { type: "Results"; start: number; duration: number; is_final: boolean; speech_final: boolean; from_finalize: boolean; channel: StreamChannel; metadata: StreamMetadata; channel_index: number[] } | { type: "Metadata"; request_id: string; created: string; duration: number; channels: number } | { type: "SpeechStarted"; channel: number[]; timestamp: number } | { type: "UtteranceEnd"; channel: number[]; last_word_end: number } | { type: "Error"; error_code: number | null; error_message: string; provider: string } -export type StreamWord = { word: string; start: number; end: number; confidence: number; speaker: number | null; punctuated_word: string | null; language: string | null } -export type Subtitle = { tokens: Token[] } -export type Token = { text: string; start_time: number; end_time: number; speaker: string | null } -export type VttWord = { text: string; start_ms: number; end_ms: number; speaker: string | null } - -/** tauri-specta globals **/ - -import { - invoke as TAURI_INVOKE, - Channel as TAURI_CHANNEL, -} from "@tauri-apps/api/core"; -import * as TAURI_API_EVENT from "@tauri-apps/api/event"; -import { type WebviewWindow as __WebviewWindow__ } from "@tauri-apps/api/webviewWindow"; +export type BatchAlternatives = { + transcript: string; + confidence: number; + words?: BatchWord[]; +}; +export type BatchChannel = { alternatives: BatchAlternatives[] }; +export type BatchEvent = + | { type: "batchStarted"; session_id: string } + | { type: "batchResponse"; session_id: string; response: BatchResponse } + | { + type: "batchProgress"; + session_id: string; + words: TranscriptWord[]; + speaker_hints: SpeakerHint[]; + percentage: number; + } + | { type: "batchFailed"; session_id: string; error: string }; +export type BatchParams = { + session_id: string; + provider: BatchProvider; + file_path: string; + model?: string | null; + base_url: string; + api_key: string; + languages?: string[]; + keywords?: string[]; +}; +export type BatchProvider = "deepgram" | "soniox" | "assemblyai" | "am"; +export type BatchResponse = { metadata: JsonValue; results: BatchResults }; +export type BatchResults = { channels: BatchChannel[] }; +export type BatchWord = { + word: string; + start: number; + end: number; + confidence: number; + speaker: number | null; + punctuated_word: string | null; +}; +export type JsonValue = + | null + | boolean + | number + | string + | JsonValue[] + | Partial<{ [key in string]: JsonValue }>; +export type SpeakerHint = { word_id: string; speaker_index: number }; +export type Subtitle = { tokens: Token[] }; +export type Token = { + text: string; + start_time: number; + end_time: number; + speaker: string | null; +}; +export type TranscriptWord = { + id: string; + text: string; + start_ms: number; + end_ms: number; + channel: number; +}; +export type VttWord = { + text: string; + start_ms: number; + end_ms: number; + speaker: string | null; +}; type __EventObj__ = { - listen: ( - cb: TAURI_API_EVENT.EventCallback, - ) => ReturnType>; - once: ( - cb: TAURI_API_EVENT.EventCallback, - ) => ReturnType>; - emit: null extends T - ? (payload?: T) => ReturnType - : (payload: T) => ReturnType; + listen: ( + cb: TAURI_API_EVENT.EventCallback, + ) => ReturnType>; + once: ( + cb: TAURI_API_EVENT.EventCallback, + ) => ReturnType>; + emit: null extends T + ? (payload?: T) => ReturnType + : (payload: T) => ReturnType; }; export type Result = - | { status: "ok"; data: T } - | { status: "error"; error: E }; + | { status: "ok"; data: T } + | { status: "error"; error: E }; function __makeEvents__>( - mappings: Record, + mappings: Record, ) { - return new Proxy( - {} as unknown as { - [K in keyof T]: __EventObj__ & { - (handle: __WebviewWindow__): __EventObj__; - }; - }, - { - get: (_, event) => { - const name = mappings[event as keyof T]; - - return new Proxy((() => {}) as any, { - apply: (_, __, [window]: [__WebviewWindow__]) => ({ - listen: (arg: any) => window.listen(name, arg), - once: (arg: any) => window.once(name, arg), - emit: (arg: any) => window.emit(name, arg), - }), - get: (_, command: keyof __EventObj__) => { - switch (command) { - case "listen": - return (arg: any) => TAURI_API_EVENT.listen(name, arg); - case "once": - return (arg: any) => TAURI_API_EVENT.once(name, arg); - case "emit": - return (arg: any) => TAURI_API_EVENT.emit(name, arg); - } - }, - }); - }, - }, - ); + return new Proxy( + {} as unknown as { + [K in keyof T]: __EventObj__ & { + (handle: __WebviewWindow__): __EventObj__; + }; + }, + { + get: (_, event) => { + const name = mappings[event as keyof T]; + + return new Proxy((() => {}) as any, { + apply: (_, __, [window]: [__WebviewWindow__]) => ({ + listen: (arg: any) => window.listen(name, arg), + once: (arg: any) => window.once(name, arg), + emit: (arg: any) => window.emit(name, arg), + }), + get: (_, command: keyof __EventObj__) => { + switch (command) { + case "listen": + return (arg: any) => TAURI_API_EVENT.listen(name, arg); + case "once": + return (arg: any) => TAURI_API_EVENT.once(name, arg); + case "emit": + return (arg: any) => TAURI_API_EVENT.emit(name, arg); + } + }, + }); + }, + }, + ); } diff --git a/plugins/listener2/src/batch.rs b/plugins/listener2/src/batch.rs index 50336696d8..a0c5c05cad 100644 --- a/plugins/listener2/src/batch.rs +++ b/plugins/listener2/src/batch.rs @@ -3,6 +3,7 @@ use std::sync::{Arc, Mutex}; use std::time::Duration; use futures_util::StreamExt; +use hypr_transcript::accumulator::TranscriptAccumulator; use owhisper_client::{ AdapterKind, ArgmaxAdapter, AssemblyAIAdapter, CactusAdapter, DashScopeAdapter, DeepgramAdapter, ElevenLabsAdapter, FireworksAdapter, GladiaAdapter, HyprnoteAdapter, @@ -84,19 +85,37 @@ pub struct BatchArgs { pub struct BatchState { pub app: tauri::AppHandle, pub session_id: String, + accumulator: TranscriptAccumulator, rx_task: tokio::task::JoinHandle<()>, shutdown_tx: Option>, } impl BatchState { - fn emit_streamed_response( - &self, - response: StreamResponse, + fn emit_words( + &mut self, + response: &StreamResponse, percentage: f64, ) -> Result<(), ActorProcessingErr> { - BatchEvent::BatchResponseStreamed { + if let Some(update) = self.accumulator.process(response) { + if !update.new_final_words.is_empty() { + BatchEvent::BatchTranscriptWords { + session_id: self.session_id.clone(), + words: update.new_final_words, + speaker_hints: update.speaker_hints, + percentage, + } + .emit(&self.app)?; + } + } + Ok(()) + } + + fn flush_words(&mut self, percentage: f64) -> Result<(), ActorProcessingErr> { + let update = self.accumulator.flush(); + BatchEvent::BatchTranscriptWords { session_id: self.session_id.clone(), - response, + words: update.new_final_words, + speaker_hints: update.speaker_hints, percentage, } .emit(&self.app)?; @@ -142,6 +161,7 @@ impl Actor for BatchActor { let state = BatchState { app: args.app, session_id: args.session_id, + accumulator: TranscriptAccumulator::new(), rx_task, shutdown_tx: Some(shutdown_tx), }; @@ -173,15 +193,7 @@ impl Actor for BatchActor { percentage, } => { tracing::info!("batch stream response received"); - - let is_final = matches!( - response.as_ref(), - StreamResponse::TranscriptResponse { is_final, .. } if *is_final - ); - - if is_final { - state.emit_streamed_response(*response, percentage)?; - } + state.emit_words(&response, percentage)?; } BatchMsg::StreamStartFailed(error) => { @@ -198,6 +210,7 @@ impl Actor for BatchActor { BatchMsg::StreamEnded => { tracing::info!("batch_stream_ended"); + state.flush_words(1.0)?; myself.stop(None); } } diff --git a/plugins/listener2/src/events.rs b/plugins/listener2/src/events.rs index cee053c5f6..28545d0a62 100644 --- a/plugins/listener2/src/events.rs +++ b/plugins/listener2/src/events.rs @@ -1,5 +1,5 @@ +use hypr_transcript::accumulator::{SpeakerHint, TranscriptWord}; use owhisper_interface::batch::Response as BatchResponse; -use owhisper_interface::stream::StreamResponse; #[macro_export] macro_rules! common_event_derives { @@ -20,9 +20,10 @@ common_event_derives! { response: BatchResponse, }, #[serde(rename = "batchProgress")] - BatchResponseStreamed { + BatchTranscriptWords { session_id: String, - response: StreamResponse, + words: Vec, + speaker_hints: Vec, percentage: f64, }, #[serde(rename = "batchFailed")]