Official PyTorch code for the training and inference pipeline of
DepMamba: Progressive Fusion Mamba for Multimodal Depression Detection
- 25/09/2024: Released the arXiv version of the paper.
- 25/09/2024: Released the training code.
In this work, we propose an audio-visual progressive fusion Mamba model for efficient multimodal depression detection, termed DepMamba. DepMamba features two core designs: hierarchical contextual modeling and progressive multimodal fusion. First, we introduce CNN and Mamba blocks to extract features from local to global scales, enriching the contextual representation of long-range sequences. Second, we propose a multimodal collaborative state space model (SSM) that extracts intermodal and intramodal information for each modality by sharing state transition matrices. A multimodal enhanced SSM then processes the concatenated audio-visual features for improved modality cohesion.
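For intuition only, below is a minimal, unofficial sketch of the "local-to-global" idea behind the hierarchical contextual modeling: a 1D convolution refines local context and a Mamba block models long-range dependencies. The class name, hyperparameters, and residual layout here are assumptions for illustration, not the released architecture; please refer to the code in this repository for the actual implementation.

```python
import torch
import torch.nn as nn
from mamba_ssm import Mamba  # installed below from the state-spaces/mamba repo

class LocalGlobalBlock(nn.Module):
    """Hypothetical sketch: local context via Conv1d, global context via Mamba."""
    def __init__(self, d_model: int, kernel_size: int = 3):
        super().__init__()
        # local modeling over the time axis
        self.conv = nn.Conv1d(d_model, d_model, kernel_size, padding=kernel_size // 2)
        self.norm = nn.LayerNorm(d_model)
        # global (long-range) modeling with a selective state space block
        self.mamba = Mamba(d_model=d_model, d_state=16, d_conv=4, expand=2)

    def forward(self, x):                      # x: (batch, time, d_model)
        local = self.conv(x.transpose(1, 2)).transpose(1, 2)
        x = self.norm(x + local)               # residual local refinement
        return x + self.mamba(x)               # residual global modeling

# Example usage (Mamba's kernels require a CUDA device):
block = LocalGlobalBlock(d_model=128).cuda()
feats = torch.randn(2, 200, 128, device="cuda")  # e.g. a 200-frame feature sequence
print(block(feats).shape)                        # torch.Size([2, 200, 128])
```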
git clone https://github.com/Jiaxin-Ye/DepMamba.git
Our code is based on Python 3.8 and CUDA 11.7. There are a few dependencies required to run the code; the major libraries, including Mamba and PyTorch, are listed as follows:
conda create -n DepMamba -c conda-forge python=3.8
conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia
conda install packaging
git clone https://github.com/Dao-AILab/causal-conv1d.git
cd causal-conv1d
git checkout v1.1.3
CAUSAL_CONV1D_FORCE_BUILD=TRUE pip install .
cd ..
git clone https://github.com/state-spaces/mamba.git
cd ./mamba
git checkout v1.1.3
MAMBA_FORCE_BUILD=TRUE pip install .
pip install -r requirement.txt
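Optionally, you can sanity-check the installation with a short snippet like the one below. The shapes are arbitrary, and a CUDA GPU is required because Mamba's selective-scan kernels only run on CUDA tensors.

```python
import torch
from mamba_ssm import Mamba

assert torch.cuda.is_available(), "Mamba's selective-scan kernels require a CUDA GPU"
block = Mamba(d_model=64, d_state=16, d_conv=4, expand=2).cuda()
x = torch.randn(1, 100, 64, device="cuda")   # (batch, length, d_model)
print(block(x).shape)                        # expected: torch.Size([1, 100, 64])
```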
If you want to record training logs, you need to log in to your own wandb account.
wandb login
Then change these lines in main.py to your own account:
wandb.init(
    project="DepMamba", entity="<your-wandb-id>", config=args, name=wandb_run_name,
)
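Once `wandb.init` is configured, per-epoch metrics can be logged with `wandb.log`. The snippet below is only a minimal illustration with placeholder metric names and values, not the project's actual logging calls.

```python
import wandb

# Minimal illustration: metric names and values are placeholders.
run = wandb.init(project="DepMamba", entity="<your-wandb-id>", name="demo-run")
for epoch in range(3):
    wandb.log({"epoch": epoch, "train/loss": 1.0 / (epoch + 1)})
run.finish()
```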
We use the D-Vlog and LMVD datasets, proposed in their respective papers. For the D-Vlog dataset, please fill in the form at the bottom of the dataset website and send a request email to the author. For the LMVD dataset, please download the features from the released Baidu Netdisk or figshare links.
Following D-Vlog's setup, the dataset is split into train, validation, and test sets with a 7:1:2 ratio. Since LMVD has no official split, we randomly split it with an 8:1:1 ratio; the specific division is stored in `./datasets/lmvd_labels.csv`.
Furthermore, you can run `extract_lmvd_npy.py` to obtain the `.npy` features used to train the model, which can then be loaded as sketched below.
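The following is a hypothetical loader sketch: it assumes one `.npy` feature file per clip and a label CSV with `id`, `label`, and `split` columns, as well as a feature directory path. These names are assumptions for illustration and may differ from the actual output of `extract_lmvd_npy.py`; adjust accordingly.

```python
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset

class LMVDFeatureDataset(Dataset):
    """Hypothetical loader: assumes one .npy feature file per clip and a CSV
    with 'id', 'label', and 'split' columns; adapt to the actual output of
    extract_lmvd_npy.py."""
    def __init__(self, csv_path="./datasets/lmvd_labels.csv",
                 feature_dir="./datasets/lmvd_npy", split="train"):
        df = pd.read_csv(csv_path)
        self.rows = df[df["split"] == split].reset_index(drop=True)
        self.feature_dir = feature_dir

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        row = self.rows.iloc[idx]
        feats = np.load(f"{self.feature_dir}/{row['id']}.npy")  # (time, feat_dim)
        return torch.from_numpy(feats).float(), int(row["label"])
```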
$ python main.py --train True --train_gender both --test_gender both --epochs 120 --batch_size 16 --learning_rate 1e-4 --model DepMamba --dataset dvlog --gpu 0
$ python main.py --train True --train_gender both --test_gender both --epochs 120 --batch_size 16 --learning_rate 1e-4 --model DepMamba --dataset lmvd --gpu 0
$ python main.py --train False --test_gender both --epochs 120 --batch_size 16 --learning_rate 1e-4 --model DepMamba --dataset dvlog --gpu 0
$ python main.py --train False --test_gender both --epochs 120 --batch_size 16 --learning_rate 1e-4 --model DepMamba --dataset lmvd --gpu 0
- If you find this project useful for your research, please cite our paper:
@article{yedepmamba,
  title      = {DepMamba: Progressive Fusion Mamba for Multimodal Depression Detection},
  author     = {Jiaxin Ye and Junping Zhang and Hongming Shan},
  journal    = {CoRR},
  volume     = {abs/2409.15936},
  year       = {2024},
  eprinttype = {arXiv},
  eprint     = {2409.15936},
}
- We acknowledge the wonderful work of Mamba and Vision Mamba, and borrow their implementations of Mamba and bidirectional Mamba.
- We acknowledge AllenYolk and ConMamba.
- The training pipelines are adapted from SpeechBrain.