Skip to content

Commit

Permalink
adding pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
pavelzbornik committed Nov 21, 2024
1 parent bfe474a commit 537ae71
Show file tree
Hide file tree
Showing 32 changed files with 156 additions and 232 deletions.
2 changes: 1 addition & 1 deletion .devcontainer/devcontainer.json
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,4 @@
},
"postCreateCommand": "pip install -r requirements.txt",
"remoteUser": "root"
}
}
4 changes: 2 additions & 2 deletions .github/workflows/CI.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name: CI
on:
push:
branches:
- main
- dev
pull_request:
branches:
- main
Expand Down Expand Up @@ -32,4 +32,4 @@ jobs:
run: docker build --cache-from=type=local,src=/tmp/.buildx-cache --cache-to=type=local,dest=/tmp/.buildx-cache -t my-app .

- name: Run tests
run: docker run my-app pytest
run: docker run my-app pytest
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
.env
__pycache__
*.ipynb
*.db
*.db
16 changes: 16 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# .pre-commit-config.yaml
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.3.0 # Use the latest tag from the repository
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
- id: check-added-large-files
- id: detect-private-key

# Add black for Python code formatting
- repo: https://github.com/psf/black
rev: 23.9.1 # Use the latest stable version of Black
hooks:
- id: black
2 changes: 1 addition & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
}
}
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -98,4 +98,4 @@ The models used by whisperX are stored in `root/.cache`, if you want to avoid do
- [ahmetoner/whisper-asr-webservice](https://github.com/ahmetoner/whisper-asr-webservice)
- [alexgo84/whisperx-server](https://github.com/alexgo84/whisperx-server)
- [chinaboard/whisperX-service](https://github.com/chinaboard/whisperX-service)
- [tijszwinkels/whisperX-api](https://github.com/tijszwinkels/whisperX-api)
- [tijszwinkels/whisperX-api](https://github.com/tijszwinkels/whisperX-api)
9 changes: 4 additions & 5 deletions app/audio.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from whisperx import load_audio

from tempfile import NamedTemporaryFile
import subprocess
from tempfile import NamedTemporaryFile

from .files import check_file_extension, VIDEO_EXTENSIONS

from whisperx import load_audio
from whisperx.audio import SAMPLE_RATE

from .files import VIDEO_EXTENSIONS, check_file_extension


def convert_video_to_audio(file):
"""
Expand Down
9 changes: 4 additions & 5 deletions app/db.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import os
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from functools import wraps
from sqlalchemy.exc import SQLAlchemyError
from fastapi import HTTPException

from dotenv import load_dotenv
from fastapi import HTTPException
from sqlalchemy import create_engine
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import sessionmaker

# Load environment variables from .env
load_dotenv()
Expand Down
5 changes: 3 additions & 2 deletions app/docs.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from sqlalchemy import inspect
import json
import yaml
import os

import yaml
from sqlalchemy import inspect

DOCS_PATH = "app/docs"


Expand Down
2 changes: 1 addition & 1 deletion app/docs/db_schema.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Database schema
# Database schema

## Table: tasks

Expand Down
2 changes: 1 addition & 1 deletion app/docs/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -2074,4 +2074,4 @@
"description": "Manage tasks."
}
]
}
}
4 changes: 2 additions & 2 deletions app/docs/service_chart.dot
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ digraph FastAPI_Services {
Speech_to_text -> Background_Task [style=dashed];

Background_Task -> Database [style=dashed];

Database -> Get_All_Tasks [style=dashed];
Database -> Get_Task_Status [style=dashed];
Delete_Task -> Database [style=dashed];

Speech_to_text [label="Speech-to-Text Service\n(/speech-to-text)\n(/service/{service_name})"];

Background_Task [shape=diamond, color=blue, label="Background Task\nWhisperX Process"];
Get_All_Tasks [label="Get All Tasks\n(/task/all)"];
Get_Task_Status [label="Task Status\n(/task/{identifier})"];
Expand Down
2 changes: 1 addition & 1 deletion app/docs/service_chart.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
8 changes: 3 additions & 5 deletions app/files.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from tempfile import NamedTemporaryFile

import logging
import os
from tempfile import NamedTemporaryFile

from fastapi import HTTPException

logging.basicConfig(level=logging.INFO)
Expand Down Expand Up @@ -60,9 +60,7 @@ def save_temporary_file(temporary_file, original_filename):
_, original_extension = os.path.splitext(original_filename)

# Create a temporary file with the original extension
temp_filename = NamedTemporaryFile(
suffix=original_extension, delete=False
).name
temp_filename = NamedTemporaryFile(suffix=original_extension, delete=False).name

# Write the contents of the SpooledTemporaryFile to the temporary file
with open(temp_filename, "wb") as dest:
Expand Down
18 changes: 6 additions & 12 deletions app/main.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,14 @@
from contextlib import asynccontextmanager

from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.responses import RedirectResponse

from .models import Base
from .files import (
AUDIO_EXTENSIONS,
VIDEO_EXTENSIONS,
)
from .db import engine
from .routers import task, stt_services, stt
from .docs import generate_db_schema, save_openapi_json

from dotenv import load_dotenv

from .files import AUDIO_EXTENSIONS, VIDEO_EXTENSIONS
from .models import Base
from .routers import stt, stt_services, task

# Load environment variables from .env
load_dotenv()
Expand All @@ -23,9 +19,7 @@
@asynccontextmanager
async def lifespan(app: FastAPI):
save_openapi_json(app)
generate_db_schema(
Base.metadata.tables.values()
)
generate_db_schema(Base.metadata.tables.values())
yield


Expand Down
18 changes: 5 additions & 13 deletions app/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from datetime import datetime
from uuid import uuid4

from sqlalchemy import Column, String, Float, JSON, Integer, DateTime
from sqlalchemy import JSON, Column, DateTime, Float, Integer, String
from sqlalchemy.orm import declarative_base

Base = declarative_base()
Expand Down Expand Up @@ -37,23 +37,15 @@ class Task(Base):
comment="Universally unique identifier for each task",
)
status = Column(String, comment="Current status of the task")
result = Column(
JSON, comment="JSON data representing the result of the task"
)
file_name = Column(
String, comment="Name of the file associated with the task"
)
result = Column(JSON, comment="JSON data representing the result of the task")
file_name = Column(String, comment="Name of the file associated with the task")
url = Column(String, comment="URL of the file associated with the task")
audio_duration = Column(Float, comment="Duration of the audio in seconds")
language = Column(
String, comment="Language of the file associated with the task"
)
language = Column(String, comment="Language of the file associated with the task")
task_type = Column(String, comment="Type/category of the task")
task_params = Column(JSON, comment="Parameters of the task")
duration = Column(Float, comment="Duration of the task execution")
error = Column(
String, comment="Error message, if any, associated with the task"
)
error = Column(String, comment="Error message, if any, associated with the task")
created_at = Column(
DateTime, default=datetime.utcnow, comment="Date and time of creation"
)
Expand Down
60 changes: 19 additions & 41 deletions app/routers/stt.py
Original file line number Diff line number Diff line change
@@ -1,56 +1,34 @@
import logging
from fastapi import (
File,
UploadFile,
Form,
Depends,
APIRouter,
)
from fastapi import BackgroundTasks

import os
from tempfile import NamedTemporaryFile
from urllib.parse import urlparse

import requests
from fastapi import APIRouter, BackgroundTasks, Depends, File, Form, UploadFile
from sqlalchemy.orm import Session

from ..audio import get_audio_duration, process_audio_file
from ..db import get_db_session
from ..files import ALLOWED_EXTENSIONS, save_temporary_file, validate_extension
from ..schemas import (
Response,
ASROptions,
VADOptions,
WhsiperModelParams,
AlignmentParams,
ASROptions,
DiarizationParams,
Response,
SpeechToTextProcessingParams,
VADOptions,
WhsiperModelParams,
)

from sqlalchemy.orm import Session

from ..audio import (
process_audio_file,
get_audio_duration,
)

from ..files import (
save_temporary_file,
validate_extension,
ALLOWED_EXTENSIONS,
)

from ..tasks import (
add_task_to_db,
)

from ..tasks import add_task_to_db
from ..whisperx_services import process_audio_common

import requests
from tempfile import NamedTemporaryFile

from ..db import get_db_session

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

stt_router = APIRouter()


@stt_router.post("/speech-to-text", tags=["Speech-2-Text"])
async def speech_to_text(
background_tasks: BackgroundTasks,
Expand Down Expand Up @@ -121,15 +99,15 @@ async def speech_to_text_url(
# Extract filename from HTTP response headers or URL
with requests.get(url, stream=True) as response:
response.raise_for_status()

# Check for filename in Content-Disposition header
content_disposition = response.headers.get('Content-Disposition')
if content_disposition and 'filename=' in content_disposition:
filename = content_disposition.split('filename=')[1].strip('"')
content_disposition = response.headers.get("Content-Disposition")
if content_disposition and "filename=" in content_disposition:
filename = content_disposition.split("filename=")[1].strip('"')
else:
# Fall back to extracting from the URL path
filename = os.path.basename(url)

# Get the file extension
_, original_extension = os.path.splitext(filename)

Expand Down
Loading

0 comments on commit 537ae71

Please sign in to comment.