- Python >= 3.9
- PyTorch
- PyTorch Lightning >= 2.0
- Decord
Update `DATA_ROOT` in `config.py` to point to the root directory of the dataset.
# Expected folder structure
```
|-- train
|   `-- scenes
|-- dev
|   `-- scenes
|-- eval
|   `-- scenes
```
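You can confirm that `DATA_ROOT` matches the layout above with a small check like the one below. This is a hypothetical helper, not part of the repository. Here `DATA_ROOT` is hard-coded as a placeholder rather than imported from `config.py`, and only the `<split>/scenes` directories shown above are checked.

```python
from pathlib import Path

# Placeholder for the value set in config.py; replace with your own path.
DATA_ROOT = "/path/to/avsec_data"

# Split folders from the expected layout; each must contain a "scenes" directory.
EXPECTED_SPLITS = ("train", "dev", "eval")


def missing_scene_dirs(data_root):
    """Return the expected <split>/scenes directories that do not exist under data_root."""
    root = Path(data_root)
    return [
        str(root / split / "scenes")
        for split in EXPECTED_SPLITS
        if not (root / split / "scenes").is_dir()
    ]


if __name__ == "__main__":
    missing = missing_scene_dirs(DATA_ROOT)
    if missing:
        print("Missing directories:")
        for path in missing:
            print("  " + path)
    else:
        print("Folder structure looks correct.")
```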
The data was preprocessed with the scripts provided in AVSEC Preprocessing.
To train the model described in the paper, run:
```bash
python train.py --log_dir ./logs --batch_size 2 --lr 0.001 --gpu 1 --max_epochs 20
```

```
optional arguments:
  -h, --help          show this help message and exit
  --batch_size 4      Batch size for training
  --lr 0.001          Learning rate for training
  --log_dir LOG_DIR   Path to save tensorboard logs
```
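The training options can be pictured as a standard `argparse` parser feeding a PyTorch Lightning `Trainer`. The sketch below is an assumption about how `train.py` might be organised, not its actual code. `build_parser` and `trainer_kwargs` are hypothetical names, `--gpu` is assumed to be a GPU count, and the defaults for `--log_dir`, `--gpu` and `--max_epochs` are guesses. The sketch shows that values passed on the command line (for example `--batch_size 2`) override the values listed in the help text (`4` for `--batch_size`, presumably its default).

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the documented train.py options."""
    parser = argparse.ArgumentParser(description="Train the model (sketch)")
    parser.add_argument("--batch_size", type=int, default=4, help="Batch size for training")
    parser.add_argument("--lr", type=float, default=0.001, help="Learning rate for training")
    # Default log directory is an assumption.
    parser.add_argument("--log_dir", type=str, default="./logs", help="Path to save tensorboard logs")
    # Not in the help text but used in the example command; meaning and defaults assumed.
    parser.add_argument("--gpu", type=int, default=1, help="Number of GPUs (assumed meaning)")
    parser.add_argument("--max_epochs", type=int, default=20, help="Maximum number of training epochs")
    return parser


def trainer_kwargs(args):
    """Map parsed flags to keyword arguments accepted by Lightning's Trainer."""
    use_gpu = args.gpu > 0
    return {
        "accelerator": "gpu" if use_gpu else "cpu",
        # Assumes --gpu is a device count; fall back to a single CPU process.
        "devices": args.gpu if use_gpu else 1,
        "max_epochs": args.max_epochs,
    }
```

With Lightning 2.x, you could unpack the resulting dictionary as `Trainer(**trainer_kwargs(args))`; `accelerator`, `devices` and `max_epochs` are all existing `Trainer` arguments.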
To evaluate the model on AVSEC3 data, run:
```
usage: test.py [-h] --ckpt_path ./model.pth --save_root ./enhanced --model_uid avse [--dev_set False] [--eval_set True] [--cpu True]

optional arguments:
  -h, --help              show this help message and exit
  --ckpt_path CKPT_PATH   Path to model checkpoint
  --save_root SAVE_ROOT   Path to save enhanced audio
  --model_uid MODEL_UID   Folder name to save enhanced audio
  --dev_set True          Evaluate model on dev set
  --eval_set False        Evaluate model on eval set
  --cpu True              Evaluate on CPU (default is GPU)
```
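The evaluation flags take `True`/`False` as string values. If `test.py` parses them with `argparse` (an assumption), note that `type=bool` would turn any non-empty string, including `"False"`, into `True`. The following minimal sketch shows a safer converter. `str2bool`, `build_test_parser` and `output_dir` are hypothetical names, the boolean defaults are placeholders, and the `<save_root>/<model_uid>` output layout is inferred from the option descriptions rather than confirmed.

```python
import argparse
import os


def str2bool(value):
    """Convert common textual booleans ("True", "false", "1", "no", ...) to bool."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_test_parser():
    """Hypothetical parser for the test.py options listed above."""
    parser = argparse.ArgumentParser(description="Evaluate a trained model (sketch)")
    parser.add_argument("--ckpt_path", required=True, help="Path to model checkpoint")
    parser.add_argument("--save_root", required=True, help="Path to save enhanced audio")
    parser.add_argument("--model_uid", required=True, help="Folder name to save enhanced audio")
    # Placeholder defaults; the real ones are defined in test.py.
    parser.add_argument("--dev_set", type=str2bool, default=False, help="Evaluate model on dev set")
    parser.add_argument("--eval_set", type=str2bool, default=False, help="Evaluate model on eval set")
    parser.add_argument("--cpu", type=str2bool, default=False, help="Evaluate on CPU (default is GPU)")
    return parser


def output_dir(args):
    """Assumed layout: enhanced audio is written under <save_root>/<model_uid>."""
    return os.path.join(args.save_root, args.model_uid)
```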