Skip to content

Commit

Permalink
Merge pull request #59 from pavelzbornik:dev
Browse files Browse the repository at this point in the history
Add support for distilled and custom models in README and schemas
  • Loading branch information
pavelzbornik authored Jan 10, 2025
2 parents ea87c5b + 4345221 commit 685187a
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 3 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ WhisperX supports these model sizes:
- `small`, `small.en`
- `medium`, `medium.en`
- `large`, `large-v1`, `large-v2`, `large-v3`, `large-v3-turbo`
- Distilled models: `distil-large-v2`, `distil-medium.en`, `distil-small.en`, `distil-large-v3`
- Custom models: [`nyrahealth/faster_CrisperWhisper`](https://github.com/nyrahealth/CrisperWhisper)

Set the default model in `.env` using `WHISPER_MODEL=` (default: `tiny`)

Expand Down
5 changes: 5 additions & 0 deletions app/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,11 @@ class WhisperModel(str, Enum):
large_v2 = "large-v2"
large_v3 = "large-v3"
large_v3_turbo = "large-v3-turbo"
distil_large_v2 = "distil-large-v2"
distil_medium_en = "distil-medium.en"
distil_small_en = "distil-small.en"
distil_large_v3 = "distil-large-v3"
faster_crisper_whisper = "nyrahealth/faster_CrisperWhisper"


class Device(str, Enum):
Expand Down
8 changes: 5 additions & 3 deletions app/whisperx_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,15 @@ def transcribe_with_whisper(

logger.debug(
"Loading model with config - model: %s, device: %s, compute_type: %s, threads: %d, task: %s, language: %s",
model,
model.value,
device,
compute_type,
faster_whisper_threads,
task,
language,
)
model = load_model(
model,
model.value,
device,
device_index=device_index,
compute_type=compute_type,
Expand All @@ -90,7 +90,9 @@ def transcribe_with_whisper(
threads=faster_whisper_threads,
)
logger.debug("Transcription model loaded successfully")
result = model.transcribe(audio=audio, batch_size=batch_size, chunk_size=chunk_size, language=language)
result = model.transcribe(
audio=audio, batch_size=batch_size, chunk_size=chunk_size, language=language
)

# Log GPU memory before cleanup
if torch.cuda.is_available():
Expand Down

0 comments on commit 685187a

Please sign in to comment.