В данном репозитории содержится реализация статьи . До этого открытый код модели отсутстовал, поэтому мной была реализована вся архитектура и логика обучения, валидации и инференса с нуля.
Модель предсказывает логарифм вероятности принадлежности произнесенного на видео слова к одному из 500 классов, представленных в датасете LRW.
Для того, чтобы установить среду разработки, выполните следующие шаги.
Создание
python3 -m venv venv
Активация
source venv/bin/activate
pip install .
Это нужно для корректной работы
pip install hydra-core --upgrade
Он нужен для обработчика EMA
git clone https://github.com/benihime91/gale
cd gale
pip install .
cd ..
Архитектура реализованной модели представлена ниже.
Модель состоит из блока 3D сверточной сети, отмасштабированной EfficientNetV2, энкодера трансформера и блока временной сверточной сети (TCN). Розовым на рисунке обозначена внешняя часть сети (frontend), выполняющая извлечение признакв, оранжевым - внутренняя часть сети (backend), отвечающая за обработку признаков.
Для обучения используется pytorch-lightning
. Для загрузки конфигураций - hydra
.
Для логирования используется Weights&Biases. Создайте аккаунт, если это необходимо, либо закомментируйте в train.py
все строчки, связанные с logger
, если не будете использовать логирование.
- Файл
setup.py
отвечает за создание среды разработки. - Файл
train.py
запускает обучение модели с конфигурацией (включая гиперпараметры) из файлаscripts/configs/config.yaml
. - Файл
eval.py
запускает тестирование модели с конфигурацией из файлаscripts/configs/config.yaml
. Не забудьте указать путь к предобученным весам модели - В модуле
scripts
содержатся все основные скрипты, отвечающие за реализацию архитектуру модели и обработку данных. - В модуле
scripts/model
содержатся файлы, реализующие архитектуру модели. В файлеscripts/model/model_module.py
содержится класс ModelModule, который реализует все необходимые методы для обучения, валидации и тестирования модели. - В скрипте
scripts/model/e2e.py
содержится архитеткура модели. - В модуле
scripts/model/efficientnet_layers
реализованы слои и блоки EfficientNetV2. - В модуле
scripts/data
реализованы скрипты для обработки данных и формирования датасета. - В модуле
scripts/callbacks/ema_callback.py
реализован обработчик EMA.
python train.py
Для обучения модели используется конфигурационный файл scripts/configs/config.yaml
. Измените параметры в нем, если это необходимо. Также в конфигурационном файле scripts/configs/data/default.yaml
укажите название вашего датасета и в файле scripts/configs/data/dataset/имя_датасета.yaml
efficientnet_size
- масштаб EfficientNetV2, может бытьT
,S
,B
use_ema
- использовать или нет экспоненциальное сглаживание,True
илиFalse
words
- количество слов датасете, на котором будет обучаться модельwords_list
- список конкретных слов из датасета, если необходимоbatch_size
- размер пакета данныхepochs
- количество эпох обученияlr
- скорость обученияdropout
- доля отключенных нейроновweight_decay
- коэффийиент при L2 регуляризации в оптимизатореwarmup_epochs
- количество эпох разогрева (для планировщика)ema_decay
- коэффициент сглаживанияworkers
- количество ядер cpuseed
- рандомin_channels
- количество входных каналовgpus
- количество GPUname_exp
- имя эксперимента (для логирования)checkpoint_dir
- директория, куда сохраняются весаcheckpoint
- последний чекпоинт модели, с которого можно продолжить обучениеpretrained_model_path
- путь к предобученной модели (к весам)
python eval.py
Для тестирования модели используется конфигурационный файл scripts/configs/config.yaml
. Добавьте в него путь к предобученной модели (поле pretrained_model_path
).
Модель | Wacc,% | Колво параметров, М |
---|---|---|
Реализация с EfficientNetV2-T | 81.6 | 8 |
Реализация с EfficientNetV2-S | 83.6 | 26 |
MobiVSR: A Visual Speech Recognition Solution for Mobile Devices | 72.2 | 4.5 |
Deformation Flow Based Two-Stream Network for Lip Reading | 84.13 | 7,95 |
Towards Practical Lipreading with Distilled and Efficient Model | 85.5-87.9 | 28.8-36.4 |