
Commit 60b1dbe

Fedir Zadniprovskyi (fedirz) authored and committed
feat: handle srt and vtt response formats
1 parent 166eb43 commit 60b1dbe

File tree: 9 files changed, +141 -46 lines changed


faster_whisper_server/config.py

Lines changed: 2 additions & 29 deletions

@@ -15,35 +15,8 @@ class ResponseFormat(enum.StrEnum):
     TEXT = "text"
     JSON = "json"
     VERBOSE_JSON = "verbose_json"
-    # NOTE: While inspecting outputs of these formats with `curl`, I noticed there's one or two "\n" inserted at the end of the response. # noqa: E501
-
-    # VTT = "vtt" # TODO
-    # 1
-    # 00:00:00,000 --> 00:00:09,220
-    # In his video on Large Language Models or LLMs, OpenAI co-founder and YouTuber Andrej Karpathy
-    #
-    # 2
-    # 00:00:09,220 --> 00:00:12,280
-    # likened LLMs to operating systems.
-    #
-    # 3
-    # 00:00:12,280 --> 00:00:13,280
-    # Karpathy said,
-    #
-    # SRT = "srt" # TODO
-    # WEBVTT
-    #
-    # 00:00:00.000 --> 00:00:09.220
-    # In his video on Large Language Models or LLMs, OpenAI co-founder and YouTuber Andrej Karpathy
-    #
-    # 00:00:09.220 --> 00:00:12.280
-    # likened LLMs to operating systems.
-    #
-    # 00:00:12.280 --> 00:00:13.280
-    # Karpathy said,
-    #
-    # 00:00:13.280 --> 00:00:19.799
-    # I see a lot of equivalence between this new LLM OS and operating systems of today.
+    SRT = "srt"
+    VTT = "vtt"
 
 
 class Device(enum.StrEnum):

faster_whisper_server/core.py

Lines changed: 56 additions & 0 deletions

@@ -172,6 +172,62 @@ def segments_to_text(segments: Iterable[Segment]) -> str:
     return "".join(segment.text for segment in segments).strip()
 
 
+def srt_format_timestamp(ts: float) -> str:
+    hours = ts // 3600
+    minutes = (ts % 3600) // 60
+    seconds = ts % 60
+    milliseconds = (ts * 1000) % 1000
+    return f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d},{int(milliseconds):03d}"
+
+
+def test_srt_format_timestamp() -> None:
+    assert srt_format_timestamp(0.0) == "00:00:00,000"
+    assert srt_format_timestamp(1.0) == "00:00:01,000"
+    assert srt_format_timestamp(1.234) == "00:00:01,234"
+    assert srt_format_timestamp(60.0) == "00:01:00,000"
+    assert srt_format_timestamp(61.0) == "00:01:01,000"
+    assert srt_format_timestamp(61.234) == "00:01:01,234"
+    assert srt_format_timestamp(3600.0) == "01:00:00,000"
+    assert srt_format_timestamp(3601.0) == "01:00:01,000"
+    assert srt_format_timestamp(3601.234) == "01:00:01,234"
+    assert srt_format_timestamp(23423.4234) == "06:30:23,423"
+
+
+def vtt_format_timestamp(ts: float) -> str:
+    hours = ts // 3600
+    minutes = (ts % 3600) // 60
+    seconds = ts % 60
+    milliseconds = (ts * 1000) % 1000
+    return f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}.{int(milliseconds):03d}"
+
+
+def test_vtt_format_timestamp() -> None:
+    assert vtt_format_timestamp(0.0) == "00:00:00.000"
+    assert vtt_format_timestamp(1.0) == "00:00:01.000"
+    assert vtt_format_timestamp(1.234) == "00:00:01.234"
+    assert vtt_format_timestamp(60.0) == "00:01:00.000"
+    assert vtt_format_timestamp(61.0) == "00:01:01.000"
+    assert vtt_format_timestamp(61.234) == "00:01:01.234"
+    assert vtt_format_timestamp(3600.0) == "01:00:00.000"
+    assert vtt_format_timestamp(3601.0) == "01:00:01.000"
+    assert vtt_format_timestamp(3601.234) == "01:00:01.234"
+    assert vtt_format_timestamp(23423.4234) == "06:30:23.423"
+
+
+def segments_to_vtt(segment: Segment, i: int) -> str:
+    start = segment.start if i > 0 else 0.0
+    result = f"{vtt_format_timestamp(start)} --> {vtt_format_timestamp(segment.end)}\n{segment.text}\n\n"
+
+    if i == 0:
+        return f"WEBVTT\n\n{result}"
+    else:
+        return result
+
+
+def segments_to_srt(segment: Segment, i: int) -> str:
+    return f"{i + 1}\n{srt_format_timestamp(segment.start)} --> {srt_format_timestamp(segment.end)}\n{segment.text}\n\n"
+
+
 def canonicalize_word(text: str) -> str:
     text = text.lower()
     # Remove non-alphabetic characters using regular expression
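
For readers who want to see how the per-segment helpers compose a full subtitle document, here is a minimal sketch (not part of the commit). FakeSegment is a hypothetical stand-in that only assumes the real Segment model exposes start, end, and text:

from dataclasses import dataclass

from faster_whisper_server.core import segments_to_srt, segments_to_vtt


@dataclass
class FakeSegment:  # hypothetical stand-in for faster_whisper_server.core.Segment
    start: float
    end: float
    text: str


segments = [
    FakeSegment(0.0, 9.22, "In his video on Large Language Models, Andrej Karpathy"),
    FakeSegment(9.22, 12.28, "likened LLMs to operating systems."),
]

# Each helper formats a single segment; concatenating the results yields the whole file.
srt_doc = "".join(segments_to_srt(s, i) for i, s in enumerate(segments))
vtt_doc = "".join(segments_to_vtt(s, i) for i, s in enumerate(segments))
print(srt_doc)  # numbered cues with comma-separated milliseconds (SRT style)
print(vtt_doc)  # starts with a "WEBVTT" header because i == 0 for the first cue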

faster_whisper_server/main.py

Lines changed: 26 additions & 8 deletions

@@ -33,7 +33,7 @@
     Task,
     config,
 )
-from faster_whisper_server.core import Segment, segments_to_text
+from faster_whisper_server.core import Segment, segments_to_srt, segments_to_text, segments_to_vtt
 from faster_whisper_server.logger import logger
 from faster_whisper_server.server_models import (
     ModelListResponse,
@@ -154,14 +154,28 @@ def segments_to_response(
     segments: Iterable[Segment],
     transcription_info: TranscriptionInfo,
     response_format: ResponseFormat,
-) -> str | TranscriptionJsonResponse | TranscriptionVerboseJsonResponse:
+) -> Response:
     segments = list(segments)
     if response_format == ResponseFormat.TEXT:  # noqa: RET503
-        return segments_to_text(segments)
+        return Response(segments_to_text(segments), media_type="text/plain")
     elif response_format == ResponseFormat.JSON:
-        return TranscriptionJsonResponse.from_segments(segments)
+        return Response(
+            TranscriptionJsonResponse.from_segments(segments).model_dump_json(),
+            media_type="application/json",
+        )
     elif response_format == ResponseFormat.VERBOSE_JSON:
-        return TranscriptionVerboseJsonResponse.from_segments(segments, transcription_info)
+        return Response(
+            TranscriptionVerboseJsonResponse.from_segments(segments, transcription_info).model_dump_json(),
+            media_type="application/json",
+        )
+    elif response_format == ResponseFormat.VTT:
+        return Response(
+            "".join(segments_to_vtt(segment, i) for i, segment in enumerate(segments)), media_type="text/vtt"
+        )
+    elif response_format == ResponseFormat.SRT:
+        return Response(
+            "".join(segments_to_srt(segment, i) for i, segment in enumerate(segments)), media_type="text/plain"
+        )
 
 
 def format_as_sse(data: str) -> str:
@@ -174,13 +188,17 @@ def segments_to_streaming_response(
     response_format: ResponseFormat,
 ) -> StreamingResponse:
     def segment_responses() -> Generator[str, None, None]:
-        for segment in segments:
+        for i, segment in enumerate(segments):
             if response_format == ResponseFormat.TEXT:
                 data = segment.text
             elif response_format == ResponseFormat.JSON:
                 data = TranscriptionJsonResponse.from_segments([segment]).model_dump_json()
             elif response_format == ResponseFormat.VERBOSE_JSON:
                 data = TranscriptionVerboseJsonResponse.from_segment(segment, transcription_info).model_dump_json()
+            elif response_format == ResponseFormat.VTT:
+                data = segments_to_vtt(segment, i)
+            elif response_format == ResponseFormat.SRT:
+                data = segments_to_srt(segment, i)
             yield format_as_sse(data)
 
     return StreamingResponse(segment_responses(), media_type="text/event-stream")
@@ -211,7 +229,7 @@ def translate_file(
     response_format: Annotated[ResponseFormat, Form()] = config.default_response_format,
     temperature: Annotated[float, Form()] = 0.0,
     stream: Annotated[bool, Form()] = False,
-) -> str | TranscriptionJsonResponse | TranscriptionVerboseJsonResponse | StreamingResponse:
+) -> Response | StreamingResponse:
     whisper = load_model(model)
     segments, transcription_info = whisper.transcribe(
         file.file,
@@ -247,7 +265,7 @@ def transcribe_file(
     ] = ["segment"],
     stream: Annotated[bool, Form()] = False,
     hotwords: Annotated[str | None, Form()] = None,
-) -> str | TranscriptionJsonResponse | TranscriptionVerboseJsonResponse | StreamingResponse:
+) -> Response | StreamingResponse:
     whisper = load_model(model)
     segments, transcription_info = whisper.transcribe(
         file.file,
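
Since the endpoints now return Response objects with format-specific media types, a client can request subtitles directly. A hedged, hypothetical usage sketch (not part of the commit), assuming a server is reachable at localhost:8000 and an audio.wav file exists locally:

import httpx

with open("audio.wav", "rb") as f:
    files = {"file": ("audio.wav", f.read(), "audio/wav")}

for response_format in ("srt", "vtt"):
    # /v1/audio/transcriptions is the endpoint exercised by the new tests below.
    response = httpx.post(
        "http://localhost:8000/v1/audio/transcriptions",
        files=files,
        data={"response_format": response_format, "stream": False},
    )
    response.raise_for_status()
    print(response.headers["content-type"])  # text/plain for srt, text/vtt for vtt
    print(response.text[:200])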

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -18,7 +18,7 @@ dependencies = [
 ]
 
 [project.optional-dependencies]
-dev = ["ruff==0.5.3", "pytest", "basedpyright==1.13.0", "pytest-xdist"]
+dev = ["ruff==0.5.3", "pytest", "webvtt-py", "srt", "basedpyright==1.13.0", "pytest-xdist"]
 
 other = ["youtube-dl @ git+https://github.com/ytdl-org/youtube-dl.git@37cea84f775129ad715b9bcd617251c831fcc980", "aider-chat==0.39.0"]
 

requirements-all.txt

Lines changed: 6 additions & 2 deletions

@@ -496,7 +496,7 @@ scipy==1.13.1
     # via aider-chat
 semantic-version==2.10.0
     # via gradio
-setuptools==71.0.3
+setuptools==71.0.4
     # via ctranslate2
 shellingham==1.5.4
     # via typer
@@ -524,11 +524,13 @@ soupsieve==2.5
     # via
     #   aider-chat
     #   beautifulsoup4
+srt==3.5.3
+    # via faster-whisper-server (pyproject.toml)
 starlette==0.37.2
     # via fastapi
 streamlit==1.35.0
     # via aider-chat
-sympy==1.13.0
+sympy==1.13.1
     # via onnxruntime
 tenacity==8.3.0
     # via
@@ -623,6 +625,8 @@ websockets==11.0.3
     # via
     #   gradio-client
     #   uvicorn
+webvtt-py==0.5.1
+    # via faster-whisper-server (pyproject.toml)
 yarl==1.9.4
     # via
     #   aider-chat

requirements-dev.txt

Lines changed: 7 additions & 3 deletions

@@ -146,7 +146,7 @@ numpy==1.26.4
     #   pandas
 onnxruntime==1.18.1
     # via faster-whisper
-openai==1.35.15
+openai==1.36.0
     # via faster-whisper-server (pyproject.toml)
 orjson==3.10.6
     # via gradio
@@ -235,7 +235,7 @@ ruff==0.5.3
     #   gradio
 semantic-version==2.10.0
     # via gradio
-setuptools==71.0.3
+setuptools==71.0.4
     # via ctranslate2
 shellingham==1.5.4
     # via typer
@@ -248,9 +248,11 @@ sniffio==1.3.1
     #   openai
 soundfile==0.12.1
     # via faster-whisper-server (pyproject.toml)
+srt==3.5.3
+    # via faster-whisper-server (pyproject.toml)
 starlette==0.37.2
     # via fastapi
-sympy==1.13.0
+sympy==1.13.1
     # via onnxruntime
 tokenizers==0.19.1
     # via faster-whisper
@@ -295,3 +297,5 @@ websockets==11.0.3
     # via
     #   gradio-client
     #   uvicorn
+webvtt-py==0.5.1
+    # via faster-whisper-server (pyproject.toml)

requirements.txt

Lines changed: 3 additions & 3 deletions

@@ -138,7 +138,7 @@ numpy==1.26.4
     #   pandas
 onnxruntime==1.18.1
     # via faster-whisper
-openai==1.35.15
+openai==1.36.0
     # via faster-whisper-server (pyproject.toml)
 orjson==3.10.6
     # via gradio
@@ -216,7 +216,7 @@ ruff==0.5.3
     # via gradio
 semantic-version==2.10.0
     # via gradio
-setuptools==71.0.3
+setuptools==71.0.4
     # via ctranslate2
 shellingham==1.5.4
     # via typer
@@ -231,7 +231,7 @@ soundfile==0.12.1
     # via faster-whisper-server (pyproject.toml)
 starlette==0.37.2
     # via fastapi
-sympy==1.13.0
+sympy==1.13.1
     # via onnxruntime
 tokenizers==0.19.1
     # via faster-whisper

tests/conftest.py

Lines changed: 2 additions & 0 deletions

@@ -1,10 +1,12 @@
 from collections.abc import Generator
 import logging
+import os
 
 from fastapi.testclient import TestClient
 from openai import OpenAI
 import pytest
 
+os.environ["WHISPER__MODEL"] = "Systran/faster-whisper-tiny.en"
 from faster_whisper_server.main import app
 
 disable_loggers = ["multipart.multipart", "faster_whisper"]

tests/sse_test.py

Lines changed: 38 additions & 0 deletions

@@ -4,6 +4,9 @@
 from fastapi.testclient import TestClient
 from httpx_sse import connect_sse
 import pytest
+import srt
+import webvtt
+import webvtt.vtt
 
 from faster_whisper_server.server_models import (
     TranscriptionJsonResponse,
@@ -61,3 +64,38 @@ def test_streaming_transcription_verbose_json(client: TestClient, file_path: str
     with connect_sse(client, "POST", endpoint, **kwargs) as event_source:
         for event in event_source.iter_sse():
             TranscriptionVerboseJsonResponse(**json.loads(event.data))
+
+
+def test_transcription_vtt(client: TestClient) -> None:
+    with open("audio.wav", "rb") as f:
+        data = f.read()
+    kwargs = {
+        "files": {"file": ("audio.wav", data, "audio/wav")},
+        "data": {"response_format": "vtt", "stream": False},
+    }
+    response = client.post("/v1/audio/transcriptions", **kwargs)
+    assert response.status_code == 200
+    assert response.headers["content-type"] == "text/vtt; charset=utf-8"
+    text = response.text
+    webvtt.from_string(text)
+    text = text.replace("WEBVTT", "YO")
+    with pytest.raises(webvtt.vtt.MalformedFileError):
+        webvtt.from_string(text)
+
+
+def test_transcription_srt(client: TestClient) -> None:
+    with open("audio.wav", "rb") as f:
+        data = f.read()
+    kwargs = {
+        "files": {"file": ("audio.wav", data, "audio/wav")},
+        "data": {"response_format": "srt", "stream": False},
+    }
+    response = client.post("/v1/audio/transcriptions", **kwargs)
+    assert response.status_code == 200
+    assert "text/plain" in response.headers["content-type"]
+
+    text = response.text
+    list(srt.parse(text))
+    text = text.replace("1", "YO")
+    with pytest.raises(srt.SRTParseError):
+        list(srt.parse(text))
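
The streaming path emits one SSE event per segment, so SRT/VTT arrives as incremental cue fragments rather than a single document. A rough consumption sketch (not part of the commit; assumes httpx and httpx-sse are installed and a server runs at localhost:8000):

import httpx
from httpx_sse import connect_sse

with open("audio.wav", "rb") as f:
    kwargs = {
        "files": {"file": ("audio.wav", f.read(), "audio/wav")},
        "data": {"response_format": "srt", "stream": True},
    }

with httpx.Client(base_url="http://localhost:8000") as client:
    with connect_sse(client, "POST", "/v1/audio/transcriptions", **kwargs) as event_source:
        for event in event_source.iter_sse():
            print(event.data)  # one SRT cue block per event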
