From cb80cd7e4832651fc46f5bcde97c8d7144b0f6c1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Orell=20B=C3=BChler?=
Date: Wed, 8 Nov 2023 18:51:03 +0100
Subject: [PATCH] fix issues, add support for url and azure blob transcription

---
Review notes (ignored by `git am`, not part of the commit message):
- Line structure was rebuilt from a whitespace-collapsed copy and checked
  against every hunk's line counts. Whitespace inside context lines and in
  whitespace-only changes is a best-effort guess; if a hunk is rejected,
  apply with `git apply --ignore-whitespace`.
- transcribe_post: the downloaded file is now removed in `finally`, and any
  `requests` failure (not only HTTP status errors) is handled.
- azure-file/url endpoints: the model is validated before downloading, a
  failed download is reported as an HTTP error instead of
  "Transcription started", and cleanup no longer fails on a missing file.
- MODEL_DATA_DIR keeps "/data/cache" as its default when the variable is unset.

 .env.example     |   6 +++
 .gitignore       |   2 +-
 main.py          | 147 +++++++++++++++++++++++++++++++++++++++------
 requirements.txt |   5 +-
 4 files changed, 139 insertions(+), 21 deletions(-)
 create mode 100644 .env.example

diff --git a/.env.example b/.env.example
new file mode 100644
index 0000000..4f8d22c
--- /dev/null
+++ b/.env.example
@@ -0,0 +1,6 @@
+MEDIA_ROOT=
+MODEL=
+MODEL_DATA_DIR=
+QUANTIZATION=
+AZURE_BLOB_STORAGE_CONNECTION_STRING=
+AZURE_BLOB_STORAGE_CONTAINER_NAME=
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index 50d319c..e08e503 100644
--- a/.gitignore
+++ b/.gitignore
@@ -157,6 +157,6 @@ cython_debug/
 #  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
 #  and can be added to the global gitignore or merged into this file.  For a more nuclear
 #  option (not recommended) you can uncomment the following to ignore the entire idea folder.
-#.idea/
+.idea/
 
 data/
diff --git a/main.py b/main.py
index 30d3df8..49b65e8 100644
--- a/main.py
+++ b/main.py
@@ -1,15 +1,24 @@
 import os
 import asyncio
+import uuid
+
 import requests
+import tempfile
+
 from functools import lru_cache
 from pathlib import Path
-from typing import Iterable, Optional
+from typing import Iterable
 
-from fastapi import FastAPI, File, Form, HTTPException, UploadFile, status, Request
+from fastapi import FastAPI, Form, HTTPException, status, Request
 from faster_whisper import WhisperModel
 from faster_whisper.transcribe import Segment
 
+from azure.storage.blob import BlobClient
+
 from download import download_model_if_not_cached
+from dotenv import load_dotenv
+
+load_dotenv()
 
 app = FastAPI()
 
@@ -23,9 +32,20 @@
 #     "text": "Imagine the wildest idea that you've ever had, and you're curious about how it might scale to something that's a 100, a 1,000 times bigger..."
 # }
 
-MODEL_DATA_DIR = "/data/cache"
+MODEL_DATA_DIR = os.getenv("MODEL_DATA_DIR", "/data/cache")
 MEDIA_DIR = os.getenv("MEDIA_ROOT")
 
+AZURE_BLOB_STORAGE_CONTAINER_NAME = os.getenv('AZURE_BLOB_STORAGE_CONTAINER_NAME')
+AZURE_BLOB_STORAGE_CONNECTION_STRING = os.getenv('AZURE_BLOB_STORAGE_CONNECTION_STRING')
+
+WHISPER_DEFAULT_SETTINGS = {
+    "whisper_model": os.getenv("MODEL"),
+    "quantization": os.getenv("QUANTIZATION"),
+    "task": "transcribe",
+    "language": "de",
+    "beam_size": 5,
+}
+
 
 @lru_cache(maxsize=1)
 def get_whisper_model(whisper_model: str, quantization: str) -> WhisperModel:
@@ -42,7 +62,7 @@ def get_whisper_model(whisper_model: str, quantization: str) -> WhisperModel:
 
 
 def transcribe(
-    audio_path: str, whisper_model: str, quantization: str, **whisper_args
+        audio_path: str, whisper_model: str, quantization: str, **whisper_args
 ) -> Iterable[Segment]:
     """Transcribe the audio file using whisper"""
 
@@ -58,18 +78,10 @@ def transcribe(
     return segments
 
 
-WHISPER_DEFAULT_SETTINGS = {
-    "whisper_model": os.getenv("MODEL"),
-    "quantization": os.getenv("QUANTIZATION"),
-    "task": "transcribe",
-    "language": "de",
-    "beam_size": 5,
-}
-
 async def transcribe_post(postback_uri: str, audio_path: str):
     print(f"{audio_path}: Starting transcription")
     segments = transcribe(audio_path, **WHISPER_DEFAULT_SETTINGS)
-    
+
     segment_dicts = []
 
     for segment in segments:
@@ -82,16 +94,23 @@ async def transcribe_post(postback_uri: str, audio_path: str):
         )
 
     data = {"content": segment_dicts}
-    r = requests.post(url=postback_uri, json=data)
-    print(r.status_code, r.reason)
+    try:
+        print(f"{audio_path}: Posting transcription to {postback_uri}")
+        r = requests.post(url=postback_uri, json=data)
+        r.raise_for_status()
+    except requests.exceptions.RequestException as e:
+        print(e)
+    finally:
+        print(f"Deleting {audio_path}")
+        os.remove(audio_path)
+
 
 @app.post("/v1/audio/transcriptions")
 async def transcriptions(
-    request: Request,
-    model: str = Form(...),
-    file: str = Form(...)
+        request: Request,
+        model: str = Form(...),
+        file: str = Form(...)
 ):
-
     print(f"Received request for {file}")
 
     postback_uri = request.headers.get("LanguageServicePostbackUri")
@@ -100,3 +119,90 @@ async def transcriptions(
     loop.create_task(transcribe_post(postback_uri, audio_path=str(file)))
 
     return "Transcription started"
+
+
+@app.post("/v1/audio/transcriptions/azure-file")
+async def transcriptions_azure_file(
+        request: Request,
+        model: str = Form(...),
+        content_file_url: str = Form(...),
+):
+    print(f"Received request for azure file {content_file_url}")
+
+    # Reject unsupported models before downloading anything.
+    if model != "whisper-ch":
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=f"Unsupported model: {model}",
+        )
+
+    postback_uri = request.headers.get("LanguageServicePostbackUri")
+
+    random_file_name = uuid.uuid4().hex
+    path = Path(f"{MEDIA_DIR}/{random_file_name}")
+
+    try:
+        blob_client = BlobClient.from_connection_string(
+            AZURE_BLOB_STORAGE_CONNECTION_STRING,
+            container_name=AZURE_BLOB_STORAGE_CONTAINER_NAME,
+            blob_name=content_file_url
+        )
+        with open(file=path, mode="wb") as fs:
+            # Stream the blob to disk instead of buffering it all in memory.
+            blob_client.download_blob().readinto(fs)
+    except Exception as e:
+        print(e)
+        # The file does not exist yet if the failure happened before open().
+        path.unlink(missing_ok=True)
+        raise HTTPException(
+            status_code=status.HTTP_502_BAD_GATEWAY,
+            detail="Could not download the audio file",
+        ) from e
+
+    loop = asyncio.get_running_loop()
+    loop.create_task(transcribe_post(postback_uri, audio_path=str(path)))
+
+    return "Transcription started"
+
+
+@app.post("/v1/audio/transcriptions/url")
+async def transcriptions_url(
+        request: Request,
+        model: str = Form(...),
+        url: str = Form(...)
+):
+    print(f"Received request for {url}")
+
+    # Reject unsupported models before downloading anything.
+    if model != "whisper-ch":
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=f"Unsupported model: {model}",
+        )
+
+    postback_uri = request.headers.get("LanguageServicePostbackUri")
+
+    random_file_name = uuid.uuid4().hex
+    path = Path(f"{MEDIA_DIR}/{random_file_name}")
+
+    try:
+        with requests.get(url, stream=True) as r:
+            r.raise_for_status()
+            with open(path, 'wb') as f:
+                for chunk in r.iter_content(chunk_size=8192):
+                    f.write(chunk)
+    except Exception as e:
+        print(e)
+        # The file does not exist yet if the failure happened before open().
+        path.unlink(missing_ok=True)
+        raise HTTPException(
+            status_code=status.HTTP_502_BAD_GATEWAY,
+            detail="Could not download the audio file",
+        ) from e
+
+    loop = asyncio.get_running_loop()
+    loop.create_task(transcribe_post(postback_uri, audio_path=str(path)))
+
+    return "Transcription started"
+
+
+@app.get("/healthz", status_code=200)
+async def health() -> str:
+    return "OK"
diff --git a/requirements.txt b/requirements.txt
index ff675dc..baf0ba0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,7 @@
 fastapi==0.103.0
 faster-whisper==0.7.1
 python-multipart==0.0.6
-uvicorn==0.23.2
\ No newline at end of file
+uvicorn==0.23.2
+python-dotenv==1.0.0
+azure-storage-blob
+requests~=2.31.0
\ No newline at end of file