Commit 32a622e
Showing 11 changed files with 1,285 additions and 0 deletions.
@@ -0,0 +1,41 @@
# MLP-Communicator

Our work was initially built on the basis of [MISA](https://github.com/declare-lab/MISA), so the overall code architecture is similar to it, including the data processing, data loader, and evaluation metrics. Thanks to the authors' open-source spirit for saving us a lot of time.

```
@article{hazarika2020misa,
  title={MISA: Modality-Invariant and -Specific Representations for Multimodal Sentiment Analysis},
  author={Hazarika, Devamanyu and Zimmermann, Roger and Poria, Soujanya},
  journal={arXiv preprint arXiv:2005.03545},
  year={2020}
}
```

## Requirements

- Python 3.8
- PyTorch 1.11.0

You can build the environment by running:

```shell
pip install -r requirements.txt
```

### Data Download

- Install the [CMU Multimodal SDK](https://github.com/A2Zadeh/CMU-MultimodalSDK). Ensure that you can run ```from mmsdk import mmdatasdk``` (a quick sanity check is sketched below).
- Option 1: Download the [pre-computed splits](https://drive.google.com/drive/folders/1IBwWNH0XjPnZWaAlP1U2tIJH6Rb3noMI?usp=sharing) and place the contents inside the ```datasets``` folder.
- Option 2: Re-create the splits by downloading the data through the MMSDK. For this, simply run the code as detailed next.

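A quick sanity check (a minimal sketch, not part of the repository) to confirm that the SDK imports cleanly and that the dataset folders sit where ```config.py``` expects them:

```python
# Sanity-check sketch: the import fails if the CMU Multimodal SDK is missing,
# and the loop reports whether the pre-computed splits are in place.
# config.py resolves datasets/ one level above the repository directory;
# adjust data_dir if you placed the splits elsewhere.
from pathlib import Path

from mmsdk import mmdatasdk  # noqa: F401 -- the import itself is the check

data_dir = Path('datasets')  # placeholder location
for name in ('MOSI', 'MOSEI'):
    folder = data_dir / name
    print(f"{folder}: {'found' if folder.exists() else 'missing'}")
```
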
### Running the code

1. Set ```word_emb_path``` in ```config.py``` to the path of the [GloVe file](http://nlp.stanford.edu/data/glove.840B.300d.zip) (see the sketch after this list).
2. Set ```sdk_dir``` to the path of the CMU-MultimodalSDK.
3. Run ```python train.py --data mosi```. Replace ```mosi``` with ```mosei``` for the other dataset.

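For reference, steps 1 and 2 correspond to two module-level variables near the top of ```config.py```; a minimal sketch of the edit, with placeholder paths, looks like this:

```python
# Excerpt-style sketch of the top of config.py -- the paths are placeholders;
# point them at your own GloVe file and CMU-MultimodalSDK checkout.
from pathlib import Path

# Step 1: the extracted GloVe embedding file (placeholder path).
word_emb_path = '/path/to/glove.840B.300d.txt'

# Step 2: by default config.py resolves the SDK relative to the project
# directory; replace with an absolute path if your checkout lives elsewhere.
project_dir = Path(__file__).resolve().parent.parent
sdk_dir = project_dir.joinpath('CMU-MultimodalSDK')
```
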
The repository is still being updated.

### Contact

For any questions, please email [zpl010720@gmail.com](mailto:zpl010720@gmail.com).
@@ -0,0 +1,130 @@

import argparse
from datetime import datetime
from pathlib import Path
import pprint

from torch import optim
import torch.nn as nn

# Path to a pretrained word embedding file (e.g. the extracted GloVe vectors).
# Set this before running; see the README.
word_emb_path = ''
assert word_emb_path, 'Set word_emb_path in config.py to your word embedding file.'
project_dir = Path(__file__).resolve().parent.parent
sdk_dir = project_dir.joinpath('CMU-MultimodalSDK')
data_dir = project_dir.joinpath('datasets')

# Dataset name (as passed on the command line) -> directory with its splits.
data_dict = {'mosi': data_dir.joinpath('MOSI'),
             'mosei': data_dir.joinpath('MOSEI'),
             'ur_funny': data_dir.joinpath('UR_FUNNY')}

# Command-line option names mapped to the corresponding torch classes.
optimizer_dict = {'RMSprop': optim.RMSprop, 'Adam': optim.Adam}
activation_dict = {'elu': nn.ELU, 'hardshrink': nn.Hardshrink, 'hardtanh': nn.Hardtanh,
                   'leakyrelu': nn.LeakyReLU, 'prelu': nn.PReLU, 'relu': nn.ReLU,
                   'rrelu': nn.RReLU, 'tanh': nn.Tanh, 'gelu': nn.GELU}


def str2bool(v):
    """Convert a string such as 'yes'/'no' to a boolean for argparse."""
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


class Config(object):
    def __init__(self, **kwargs):
        """Configuration class: set kwargs as class attributes with setattr."""
        if kwargs is not None:
            for key, value in kwargs.items():
                if key == 'optimizer':
                    value = optimizer_dict[value]
                if key == 'activation':
                    value = activation_dict[value]
                setattr(self, key, value)

        self.dataset_dir = data_dict[self.data.lower()]
        self.sdk_dir = sdk_dir
        # GloVe path
        self.word_emb_path = word_emb_path

        self.data_dir = self.dataset_dir

    def __str__(self):
        """Pretty-print configurations in alphabetical order."""
        config_str = 'Configurations\n'
        config_str += pprint.pformat(self.__dict__)
        return config_str


def get_config(parse=True, **optional_kwargs):
    """
    Get configurations as attributes of a Config instance.
    1. Parse configurations with argparse.
    2. Create a Config object initialized with the parsed kwargs.
    3. Return the Config object.
    """
    parser = argparse.ArgumentParser()

    # Mode
    parser.add_argument('--mode', type=str, default='train')

    # parser.add_argument('--use_bert', type=str2bool, default=True)

    # Train
    time_now = datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
    parser.add_argument('--name', type=str, default=f"{time_now}")
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--n_epoch', type=int, default=100)
    parser.add_argument('--patience', type=int, default=6)
    parser.add_argument('--trials', type=int, default=3)

    parser.add_argument('--learning_rate', type=float, default=1e-4)
    parser.add_argument('--optimizer', type=str, default='Adam')

    parser.add_argument('--rnncell', type=str, default='lstm')
    parser.add_argument('--embedding_size', type=int, default=300)
    parser.add_argument('--hidden_size', type=int, default=128)
    parser.add_argument('--mlp_hidden_size', type=int, default=64)
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--depth', type=int, default=1)

    # Select an activation from 'elu', 'hardshrink', 'hardtanh', 'leakyrelu',
    # 'prelu', 'relu', 'rrelu', 'tanh', 'gelu'.
    parser.add_argument('--activation', type=str, default='relu')

    parser.add_argument('--cls_weight', type=float, default=1)
    parser.add_argument('--polar_weight', type=float, default=0.1)
    parser.add_argument('--scale_weight', type=float, default=0.1)

    parser.add_argument('--model', type=str,
                        default='MISA', help='one of {MISA, }')

    parser.add_argument('--test_duration', type=int, default=1)

    # Data
    parser.add_argument('--data', type=str, default='mosi')

    # Parse arguments
    if parse:
        kwargs = parser.parse_args()
    else:
        kwargs = parser.parse_known_args()[0]

    print(kwargs.data)
    # Dataset-specific defaults override the generic command-line defaults.
    if kwargs.data == "mosi":
        kwargs.num_classes = 1
        kwargs.batch_size = 128
        kwargs.depth = 1
    elif kwargs.data == "mosei":
        kwargs.num_classes = 1
        kwargs.batch_size = 64
        kwargs.depth = 2
    else:
        print("No dataset mentioned")
        exit()

    # Namespace => Dictionary
    kwargs = vars(kwargs)
    kwargs.update(optional_kwargs)

    return Config(**kwargs)
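
For orientation, here is a minimal usage sketch of this configuration module (assuming it is the ```config.py``` referenced in the README, and that ```word_emb_path``` has been set; the actual call site in ```train.py``` is not part of this commit):

```python
# Minimal usage sketch, not part of the repository.
# parse=False makes get_config fall back to parse_known_args, so unknown
# command-line flags are ignored; keyword arguments override parsed values.
from config import get_config

config = get_config(parse=False, batch_size=32)

print(config)              # Config.__str__ pretty-prints every setting
print(config.optimizer)    # torch.optim.Adam, mapped from the 'Adam' string
print(config.activation)   # torch.nn.ReLU, mapped from the 'relu' string
print(config.num_classes)  # 1 for the default MOSI setting
print(config.dataset_dir)  # <project parent>/datasets/MOSI, from data_dict
```

Because ```kwargs.update(optional_kwargs)``` runs after the dataset-specific defaults are applied, keyword arguments passed to ```get_config``` always take precedence (here, ```batch_size=32``` overrides the MOSI default of 128).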