Skip to content

Commit

Permalink
fix issues, add support for url and azure blob transcription
Browse files Browse the repository at this point in the history
  • Loading branch information
OrellBuehler committed Nov 8, 2023
1 parent 1249158 commit cb80cd7
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 21 deletions.
6 changes: 6 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
MEDIA_ROOT=
MODEL=
MODEL_DATA_DIR=
QUANTIZATION=
AZURE_BLOB_STORAGE_CONNECTION_STRING=
AZURE_BLOB_STORAGE_CONTAINER_NAME=
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,6 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.idea/

data/
124 changes: 105 additions & 19 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,24 @@
import os
import asyncio
import uuid

import requests
import tempfile

from functools import lru_cache
from pathlib import Path
from typing import Iterable, Optional
from typing import Iterable

from fastapi import FastAPI, File, Form, HTTPException, UploadFile, status, Request
from fastapi import FastAPI, Form, HTTPException, status, Request
from faster_whisper import WhisperModel
from faster_whisper.transcribe import Segment

from azure.storage.blob import BlobClient

from download import download_model_if_not_cached
from dotenv import load_dotenv

load_dotenv()

app = FastAPI()

Expand All @@ -23,9 +32,20 @@
# "text": "Imagine the wildest idea that you've ever had, and you're curious about how it might scale to something that's a 100, a 1,000 times bigger..."
# }

MODEL_DATA_DIR = "/data/cache"
MODEL_DATA_DIR = os.getenv("MODEL_DATA_DIR")
MEDIA_DIR = os.getenv("MEDIA_ROOT")

AZURE_BLOB_STORAGE_CONTAINER_NAME = os.getenv('AZURE_BLOB_STORAGE_CONTAINER_NAME')
AZURE_BLOB_STORAGE_CONNECTION_STRING = os.getenv('AZURE_BLOB_STORAGE_CONNECTION_STRING')

WHISPER_DEFAULT_SETTINGS = {
"whisper_model": os.getenv("MODEL"),
"quantization": os.getenv("QUANTIZATION"),
"task": "transcribe",
"language": "de",
"beam_size": 5,
}


@lru_cache(maxsize=1)
def get_whisper_model(whisper_model: str, quantization: str) -> WhisperModel:
Expand All @@ -42,7 +62,7 @@ def get_whisper_model(whisper_model: str, quantization: str) -> WhisperModel:


def transcribe(
audio_path: str, whisper_model: str, quantization: str, **whisper_args
audio_path: str, whisper_model: str, quantization: str, **whisper_args
) -> Iterable[Segment]:
"""Transcribe the audio file using whisper"""

Expand All @@ -58,18 +78,10 @@ def transcribe(
return segments


WHISPER_DEFAULT_SETTINGS = {
"whisper_model": os.getenv("MODEL"),
"quantization": os.getenv("QUANTIZATION"),
"task": "transcribe",
"language": "de",
"beam_size": 5,
}

async def transcribe_post(postback_uri: str, audio_path: str):
print(f"{audio_path}: Starting transcription")
segments = transcribe(audio_path, **WHISPER_DEFAULT_SETTINGS)

segment_dicts = []

for segment in segments:
Expand All @@ -82,16 +94,23 @@ async def transcribe_post(postback_uri: str, audio_path: str):
)

data = {"content": segment_dicts}
r = requests.post(url=postback_uri, json=data)
print(r.status_code, r.reason)
try:
print(f"{audio_path}: Posting transcription to {postback_uri}")
r = requests.post(url=postback_uri, json=data)
r.raise_for_status()
print(f"Deleting {audio_path}")
os.remove(audio_path)
except requests.exceptions.HTTPError as e:
print(e)
os.remove(audio_path)


@app.post("/v1/audio/transcriptions")
async def transcriptions(
request: Request,
model: str = Form(...),
file: str = Form(...)
request: Request,
model: str = Form(...),
file: str = Form(...)
):

print(f"Received request for {file}")
postback_uri = request.headers.get("LanguageServicePostbackUri")

Expand All @@ -100,3 +119,70 @@ async def transcriptions(
loop.create_task(transcribe_post(postback_uri, audio_path=str(file)))

return "Transcription started"


@app.post("/v1/audio/transcriptions/azure-file")
async def transcriptions_azure_file(
request: Request,
model: str = Form(...),
content_file_url: str = Form(...),
):
print(f"Received request for azure file {content_file_url}")

postback_uri = request.headers.get("LanguageServicePostbackUri")

random_file_name = uuid.uuid4().hex
path = Path(f"{MEDIA_DIR}/{random_file_name}")

try:
blob_client = BlobClient.from_connection_string(
AZURE_BLOB_STORAGE_CONNECTION_STRING,
container_name=AZURE_BLOB_STORAGE_CONTAINER_NAME,
blob_name=content_file_url
)
with open(file=path, mode="wb") as fs:
download_stream = blob_client.download_blob()
fs.write(download_stream.readall())

assert model == "whisper-ch"
loop = asyncio.get_running_loop()
loop.create_task(transcribe_post(postback_uri, audio_path=str(path)))
except Exception as e:
print(e)
os.remove(path)

return "Transcription started"


@app.post("/v1/audio/transcriptions/url")
async def transcriptions_url(
request: Request,
model: str = Form(...),
url: str = Form(...)
):
print(f"Received request for {url}")
postback_uri = request.headers.get("LanguageServicePostbackUri")

random_file_name = uuid.uuid4().hex
path = Path(f"{MEDIA_DIR}/{random_file_name}")

try:
with requests.get(url, stream=True) as r:
r.raise_for_status()
with open(path, 'wb') as f:
for chunk in r.iter_content(chunk_size=8192):
f.write(chunk)

assert model == "whisper-ch"
loop = asyncio.get_running_loop()
loop.create_task(transcribe_post(postback_uri, audio_path=str(path)))
except Exception as e:
print(e)
os.remove(path)

return "Transcription started"


@app.get("/healthz", status_code=200)
async def health() -> str:
return "OK"
5 changes: 4 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
fastapi==0.103.0
faster-whisper==0.7.1
python-multipart==0.0.6
uvicorn==0.23.2
uvicorn==0.23.2
python-dotenv==1.0.0
azure-storage-blob
requests~=2.31.0

0 comments on commit cb80cd7

Please sign in to comment.