Skip to content

Latest commit

 

History

History
106 lines (87 loc) · 8.06 KB

README.md

File metadata and controls

106 lines (87 loc) · 8.06 KB

Вычислительно эффективная модель для визуальной классикации произнесенных на видео слов

В данном репозитории содержится реализация статьи . До этого открытый код модели отсутстовал, поэтому мной была реализована вся архитектура и логика обучения, валидации и инференса с нуля.

Модель предсказывает логарифм вероятности принадлежности произнесенного на видео слова к одному из 500 классов, представленных в датасете LRW.

Установка зависимостей

Для того, чтобы установить среду разработки, выполните следующие шаги.

Создайте и активируйте venv

Создание

python3 -m venv venv

Активация

source venv/bin/activate

Соберите среду

pip install .

Обновите hydra

Это нужно для корректной работы

pip install hydra-core --upgrade

Поставьте gale

Он нужен для обработчика EMA

git clone https://github.com/benihime91/gale
cd gale
pip install .
cd ..

Реализованная архитектура

Архитектура реализованной модели представлена ниже. photo_2024-06-13_00-52-23

Модель состоит из блока 3D сверточной сети, отмасштабированной EfficientNetV2, энкодера трансформера и блока временной сверточной сети (TCN). Розовым на рисунке обозначена внешняя часть сети (frontend), выполняющая извлечение признакв, оранжевым - внутренняя часть сети (backend), отвечающая за обработку признаков.

Логика программного комплекса обучения и валидации модели

program_train

Для обучения используется pytorch-lightning. Для загрузки конфигураций - hydra.

Для логирования используется Weights&Biases. Создайте аккаунт, если это необходимо, либо закомментируйте в train.py все строчки, связанные с logger, если не будете использовать логирование.

Описание конфигурации проекта

  • Файл setup.py отвечает за создание среды разработки.
  • Файл train.py запускает обучение модели с конфигурацией (включая гиперпараметры) из файла scripts/configs/config.yaml.
  • Файл eval.py запускает тестирование модели с конфигурацией из файла scripts/configs/config.yaml. Не забудьте указать путь к предобученным весам модели
  • В модуле scripts содержатся все основные скрипты, отвечающие за реализацию архитектуру модели и обработку данных.
  • В модуле scripts/model содержатся файлы, реализующие архитектуру модели. В файле scripts/model/model_module.py содержится класс ModelModule, который реализует все необходимые методы для обучения, валидации и тестирования модели.
  • В скрипте scripts/model/e2e.py содержится архитеткура модели.
  • В модуле scripts/model/efficientnet_layers реализованы слои и блоки EfficientNetV2.
  • В модуле scripts/data реализованы скрипты для обработки данных и формирования датасета.
  • В модуле scripts/callbacks/ema_callback.py реализован обработчик EMA.

Запуск обучения

python train.py

Для обучения модели используется конфигурационный файл scripts/configs/config.yaml. Измените параметры в нем, если это необходимо. Также в конфигурационном файле scripts/configs/data/default.yaml укажите название вашего датасета и в файле scripts/configs/data/dataset/имя_датасета.yaml

Описание полей конфигурацинного файла

  • efficientnet_size - масштаб EfficientNetV2, может быть T, S, B
  • use_ema- использовать или нет экспоненциальное сглаживание, True или False
  • words - количество слов датасете, на котором будет обучаться модель
  • words_list - список конкретных слов из датасета, если необходимо
  • batch_size - размер пакета данных
  • epochs - количество эпох обучения
  • lr - скорость обучения
  • dropout - доля отключенных нейронов
  • weight_decay - коэффийиент при L2 регуляризации в оптимизаторе
  • warmup_epochs - количество эпох разогрева (для планировщика)
  • ema_decay - коэффициент сглаживания
  • workers - количество ядер cpu
  • seed - рандом
  • in_channels - количество входных каналов
  • gpus - количество GPU
  • name_exp - имя эксперимента (для логирования)
  • checkpoint_dir - директория, куда сохраняются веса
  • checkpoint - последний чекпоинт модели, с которого можно продолжить обучение
  • pretrained_model_path - путь к предобученной модели (к весам)

Запуск инференса

python eval.py

Для тестирования модели используется конфигурационный файл scripts/configs/config.yaml. Добавьте в него путь к предобученной модели (поле pretrained_model_path).

Сравнение с другими моделями

Модель Wacc,% Колво параметров, М
Реализация с EfficientNetV2-T 81.6 8
Реализация с EfficientNetV2-S 83.6 26
MobiVSR: A Visual Speech Recognition Solution for Mobile Devices 72.2 4.5
Deformation Flow Based Two-Stream Network for Lip Reading 84.13 7,95
Towards Practical Lipreading with Distilled and Efficient Model 85.5-87.9 28.8-36.4