
DepMamba Logo

Official PyTorch code for the training and inference pipeline of
DepMamba: Progressive Fusion Mamba for Multimodal Depression Detection

DepMamba pipeline


📰 News

  • 25/09/2024 Released the arXiv version of the paper.
  • 25/09/2024 Released the training code.

📕 Introduction

In this work, we propose an audio-visual progressive fusion Mamba model for efficient multimodal depression detection, termed DepMamba. Specifically, DepMamba features two core designs: hierarchical contextual modeling and progressive multimodal fusion. First, we introduce CNN and Mamba blocks to extract features from local to global scales, enriching the contextual representation of long-range sequences. Second, we propose a multimodal collaborative State Space Model (SSM) that extracts intermodal and intramodal information for each modality by sharing state transition matrices. A multimodal enhanced SSM then processes the concatenated audio-visual features for improved modality cohesion.

DepMamba pipeline
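To make the shared-transition idea concrete, the sketch below shows a toy collaborative SSM in which both modality streams evolve under the same (diagonal) state-transition parameters while keeping modality-specific input/output projections. The class name, the simplified linear recurrence, and all parameter names here are our illustrative assumptions, not the repository's actual implementation.

import torch
import torch.nn as nn

class CollaborativeSSM(nn.Module):
    """Toy collaborative SSM: two modalities share the state-transition dynamics."""
    def __init__(self, d_model, d_state=16):
        super().__init__()
        # Shared (diagonal) state-transition parameters couple the modalities.
        self.log_A = nn.Parameter(torch.randn(d_state))
        # Modality-specific input/output projections keep intramodal information.
        self.B_audio = nn.Linear(d_model, d_state, bias=False)
        self.B_video = nn.Linear(d_model, d_state, bias=False)
        self.C_audio = nn.Linear(d_state, d_model, bias=False)
        self.C_video = nn.Linear(d_state, d_model, bias=False)

    def scan(self, x, B, C):
        # x: (batch, seq_len, d_model); plain sequential linear recurrence.
        A_bar = torch.exp(-torch.exp(self.log_A))  # stable decay in (0, 1)
        h = x.new_zeros(x.size(0), A_bar.numel())
        outs = []
        for t in range(x.size(1)):
            h = A_bar * h + B(x[:, t])   # shared dynamics, per-modality input
            outs.append(C(h))
        return torch.stack(outs, dim=1)

    def forward(self, audio, video):
        return self.scan(audio, self.B_audio, self.C_audio), \
               self.scan(video, self.B_video, self.C_video)

ssm = CollaborativeSSM(d_model=64)
a_out, v_out = ssm(torch.randn(2, 50, 64), torch.randn(2, 50, 64))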

📖 Usage

1. Clone Repository

git clone https://github.com/Jiaxin-Ye/DepMamba.git

2. Requirements

Our code is based on Python 3.8 and CUDA 11.7. The major dependencies, including PyTorch and Mamba, can be installed as follows:

conda create -n DepMamba -c conda-forge python=3.8
conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia
conda install packaging
git clone https://github.com/Dao-AILab/causal-conv1d.git 
cd causal-conv1d 
git checkout v1.1.3 
CAUSAL_CONV1D_FORCE_BUILD=TRUE pip install .
cd ..
git clone https://github.com/state-spaces/mamba.git
cd ./mamba
git checkout v1.1.3
MAMBA_FORCE_BUILD=TRUE pip install .
pip install -r requirement.txt
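After installation, a quick sanity check (our suggestion, not part of the repository) confirms that PyTorch sees the GPU and that the Mamba kernels built correctly:

# Verify CUDA is visible and the fused selective-scan kernels from
# mamba_ssm run on GPU (Mamba's fast path requires CUDA).
import torch
from mamba_ssm import Mamba

print(torch.__version__, torch.cuda.is_available())
m = Mamba(d_model=64, d_state=16, d_conv=4, expand=2).cuda()
y = m(torch.randn(1, 100, 64, device="cuda"))
print(y.shape)  # expected: torch.Size([1, 100, 64])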

If you want to record training logs, you need to log in to your own wandb account.

wandb login

Change these lines in main.py to your own account.

wandb.init(
    project="DepMamba", entity="<your-wandb-id>", config=args, name=wandb_run_name,
)
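If you prefer not to sync logs to the cloud, wandb also supports local-only logging: run `wandb offline` (or set the environment variable `WANDB_MODE=offline`) before training.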

3. Prepare Datasets

We use the D-Vlog and LMVD datasets, proposed in this paper. For the D-Vlog dataset, please fill in the form at the bottom of the dataset website and send a request email to the author. For the LMVD dataset, please download the features from the released Baidu Netdisk website or figshare.

Following D-Vlog's setup, the dataset is split into train, validation, and test sets with a 7:1:2 ratio. Since LMVD has no official split, we randomly split it with an 8:1:1 ratio; the specific division is stored in `./datasets/lmvd_labels.csv`.

Furthermore, you can run extract_lmvd_npy.py to obtain the .npy features used to train the model, as in the loader sketch below.
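For reference, a loader for the extracted features might look like the following. The file naming, array shapes, and CSV columns (`split`, `id`, `label`) are assumptions for illustration; the repository's actual dataset class may differ.

import numpy as np
import pandas as pd
from torch.utils.data import Dataset

class LMVDFeatureDataset(Dataset):
    """Loads precomputed .npy audio/visual features; naming is illustrative."""
    def __init__(self, label_csv, feature_dir, split="train"):
        df = pd.read_csv(label_csv)
        self.rows = df[df["split"] == split].reset_index(drop=True)  # assumed column
        self.feature_dir = feature_dir

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, i):
        row = self.rows.iloc[i]
        # Assumed file naming: one acoustic and one visual array per sample id.
        audio = np.load(f"{self.feature_dir}/{row['id']}_acoustic.npy")
        video = np.load(f"{self.feature_dir}/{row['id']}_visual.npy")
        return audio.astype(np.float32), video.astype(np.float32), int(row["label"])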

4. Training and Testing

Training

$ python main.py --train True --train_gender both --test_gender both --epochs 120 --batch_size 16 --learning_rate 1e-4 --model DepMamba --dataset dvlog --gpu 0

$ python main.py --train True --train_gender both --test_gender both --epochs 120 --batch_size 16 --learning_rate 1e-4 --model DepMamba --dataset lmvd --gpu 0

Testing

$ python main.py --train False --test_gender both --epochs 120 --batch_size 16 --learning_rate 1e-4 --model DepMamba --dataset dvlog --gpu 0

$ python main.py --train False --test_gender both --epochs 120 --batch_size 16 --learning_rate 1e-4 --model DepMamba --dataset lmvd --gpu 0
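Note that `--train` is supplied as a literal string (`True`/`False`). If you adapt the scripts, one common way to parse such boolean flags robustly is sketched below; this is a generic pattern, not necessarily how main.py implements it.

import argparse

def str2bool(v):
    # Accept common spellings so "--train True" and "--train false" both work.
    if v.lower() in ("true", "1", "yes"):
        return True
    if v.lower() in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {v!r}")

parser = argparse.ArgumentParser()
parser.add_argument("--train", type=str2bool, default=True)
print(parser.parse_args(["--train", "False"]).train)  # -> False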

📖 Citation

  • If you find this project useful for your research, please cite our paper:
@article{yedepmamba,
  title      = {DepMamba: Progressive Fusion Mamba for Multimodal Depression Detection},
  author     = {Jiaxin Ye and Junping Zhang and Hongming Shan},
  journal    = {CoRR},
  volume     = {abs/2409.15936},
  year       = {2024},
  eprinttype = {arXiv},
  eprint     = {2409.15936},
}

🙌🏻 Acknowledgement
