Skip to content

Commit 55162ec

Browse files
committed
add some notes
1 parent 32a622e commit 55162ec

File tree

7 files changed

+41
-63
lines changed

7 files changed

+41
-63
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ We begin our work on the basis of [MISA](https://github.com/declare-lab/MISA) in
1616
- Python 3.8
1717
- Pytorch 1.11.0
1818

19-
you could run the code to build the environment.
19+
you could run the following command to build the environment.
2020

2121
```shell
2222
pip install requirements.txt

config.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
import argparse
32
from datetime import datetime
43
from pathlib import Path
@@ -10,7 +9,7 @@
109
word_emb_path = ''
1110
assert(word_emb_path is not None)
1211

13-
project_dir = Path(__file__).resolve().parent.parent
12+
project_dir = Path(__file__).resolve().parent
1413
sdk_dir = project_dir.joinpath('CMU-MultimodalSDK')
1514
data_dir = project_dir.joinpath('datasets')
1615
data_dict = {'mosi': data_dir.joinpath('MOSI'), 'mosei': data_dir.joinpath(
@@ -69,11 +68,9 @@ def get_config(parse=True, **optional_kwargs):
6968
# Mode
7069
parser.add_argument('--mode', type=str, default='train')
7170

72-
# parser.add_argument('--use_bert', type=str2bool, default=True)
73-
7471
# Train
7572
time_now = datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
76-
parser.add_argument('--name', type=str, default=f"{time_now}")
73+
parser.add_argument('--name', type=str, default=f"{time_now}") # saved-model's name
7774
parser.add_argument('--batch_size', type=int, default=128)
7875
parser.add_argument('--n_epoch', type=int, default=100)
7976
parser.add_argument('--patience', type=int, default=6)
@@ -82,16 +79,17 @@ def get_config(parse=True, **optional_kwargs):
8279
parser.add_argument('--learning_rate', type=float, default=1e-4)
8380
parser.add_argument('--optimizer', type=str, default='Adam')
8481

85-
parser.add_argument('--rnncell', type=str, default='lstm')
86-
parser.add_argument('--embedding_size', type=int, default=300)
87-
parser.add_argument('--hidden_size', type=int, default=128)
88-
parser.add_argument('--mlp_hidden_size', type=int, default=64)
82+
parser.add_argument('--rnncell', type=str, default='lstm') # lstm or GRU
83+
parser.add_argument('--embedding_size', type=int, default=300) # embedding size in bert
84+
parser.add_argument('--hidden_size', type=int, default=128) # modality embedding size
85+
parser.add_argument('--mlp_hidden_size', type=int, default=64) # mlp-communicator hidden size
8986
parser.add_argument('--dropout', type=float, default=0.5)
90-
parser.add_argument('--depth', type=int, default=1)
87+
parser.add_argument('--depth', type=int, default=1) # mlp-communicator depth number
9188

9289
# Selectin activation from 'elu', "hardshrink", "hardtanh", "leakyrelu", "prelu", "relu", "rrelu", "tanh"
9390
parser.add_argument('--activation', type=str, default='relu')
9491

92+
# three loss weights
9593
parser.add_argument('--cls_weight', type=float, default=1)
9694
parser.add_argument('--polar_weight', type=float, default=0.1)
9795
parser.add_argument('--scale_weight', type=float, default=0.1)
@@ -102,7 +100,7 @@ def get_config(parse=True, **optional_kwargs):
102100
parser.add_argument('--test_duration', type=int, default=1)
103101

104102
# Data
105-
parser.add_argument('--data', type=str, default='mosi')
103+
parser.add_argument('--data', type=str, default='mosi') # mosi or mosei
106104

107105
# Parse arguments
108106
if parse:

create_dataset.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,23 @@
55
import numpy as np
66
from tqdm import tqdm_notebook
77
from collections import defaultdict
8-
from mmsdk import mmdatasdk as md
8+
9+
try:
10+
from mmsdk import mmdatasdk as md
11+
except:
12+
print("please install mmsdk first(https://github.com/A2Zadeh/CMU-MultimodalSDK)")
913
from subprocess import check_call
1014
import torch
1115

12-
16+
"""
17+
The method of data process is used in MISA, thanks to:
18+
@article{hazarika2020misa,
19+
title={MISA: Modality-Invariant and-Specific Representations for Multimodal Sentiment Analysis},
20+
author={Hazarika, Devamanyu and Zimmermann, Roger and Poria, Soujanya},
21+
journal={arXiv preprint arXiv:2005.03545},
22+
year={2020}
23+
}
24+
"""
1325
def to_pickle(obj, path):
1426
with open(path, 'wb') as f:
1527
pickle.dump(obj, f)
@@ -31,6 +43,7 @@ def return_unk():
3143
return UNK
3244

3345

46+
# transfer word into glove embedding
3447
def load_emb(w2i, path_to_embedding, embedding_size=300, embedding_vocab=2196017, init_emb=None):
3548
if init_emb is None:
3649
emb_mat = np.random.randn(len(w2i), embedding_size)
@@ -52,9 +65,8 @@ def load_emb(w2i, path_to_embedding, embedding_size=300, embedding_vocab=2196017
5265

5366
class MOSI:
5467
def __init__(self, config):
55-
5668
if config.sdk_dir is None:
57-
print("SDK path is not specified! Please specify first in constants/paths.py")
69+
print("SDK path is not specified! Please specify first in config.py")
5870
exit(0)
5971
else:
6072
sys.path.append(str(config.sdk_dir))
@@ -104,7 +116,6 @@ def __init__(self, config):
104116
]
105117

106118
recipe = {feat: os.path.join(DATA_PATH, feat) + '.csd' for feat in features}
107-
print(recipe)
108119
dataset = md.mmdataset(recipe)
109120

110121
# we define a simple averaging function that does not depend on intervals
@@ -229,7 +240,7 @@ class MOSEI:
229240
def __init__(self, config):
230241

231242
if config.sdk_dir is None:
232-
print("SDK path is not specified! Please specify first in constants/paths.py")
243+
print("SDK path is not specified! Please specify first in config.py")
233244
exit(0)
234245
else:
235246
sys.path.append(str(config.sdk_dir))
@@ -243,9 +254,7 @@ def __init__(self, config):
243254
self.dev = load_pickle(DATA_PATH + '/dev.pkl')
244255
self.test = load_pickle(DATA_PATH + '/test.pkl')
245256
self.pretrained_emb, self.word2id = torch.load(CACHE_PATH)
246-
247257
except:
248-
249258
# create folders for storing the data
250259
if not os.path.exists(DATA_PATH):
251260
check_call(' '.join(['mkdir', '-p', DATA_PATH]), shell=True)
@@ -403,4 +412,3 @@ def get_data(self, mode):
403412
else:
404413
print("Mode is not set properly (train/dev/test)")
405414
exit()
406-

data_loader.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,9 @@
77

88
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
99

10-
1110
class MSADataset(Dataset):
1211
def __init__(self, config):
13-
14-
## Fetch dataset
12+
# Fetch dataset
1513
if "mosi" in str(config.data_dir).lower():
1614
dataset = MOSI(config)
1715
elif "mosei" in str(config.data_dir).lower():
@@ -22,7 +20,6 @@ def __init__(self, config):
2220

2321
self.data, self.word2id, self.pretrained_emb = dataset.get_data(config.mode)
2422
self.len = len(self.data)
25-
2623
self.label = np.abs(np.array(self.data)[:, 1])
2724

2825
config.visual_size = self.data[0][0][1].shape[1]
@@ -40,9 +37,7 @@ def __len__(self):
4037

4138
def get_loader(config, shuffle=True):
4239
"""Load DataLoader of given DialogDataset"""
43-
4440
dataset = MSADataset(config)
45-
4641
print(config.mode)
4742
config.data_len = len(dataset)
4843

@@ -51,9 +46,7 @@ def collate_fn(batch):
5146
Collate functions assume batch = [Dataset[i] for i in index_set]
5247
'''
5348
# for later use we sort the batch in descending order of length
54-
5549
batch = sorted(batch, key=lambda x: x[0][0].shape[0], reverse=True)
56-
5750
# get the data out of the batch - use pad sequence util functions from PyTorch to pad things
5851

5952
labels = torch.cat([torch.from_numpy(sample[1]) for sample in batch], dim=0)

new_models.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from transformers import BertModel, BertConfig
55
from einops.layers.torch import Rearrange
66

7-
87
# let's define a simple model that can deal with multimodal variable length sequence
98
class MISA(nn.Module):
109
def __init__(self, config):
@@ -13,7 +12,6 @@ def __init__(self, config):
1312
self.text_size = config.embedding_size
1413
self.visual_size = config.visual_size
1514
self.acoustic_size = config.acoustic_size
16-
1715
self.input_sizes = input_sizes = [self.text_size, self.visual_size, self.acoustic_size]
1816
self.hidden_sizes = hidden_sizes = [int(self.text_size), int(self.visual_size), int(self.acoustic_size)]
1917
self.output_size = output_size = config.num_classes
@@ -65,8 +63,8 @@ def __init__(self, config):
6563
self.shared1.add_module('shared_activation_1', nn.Sigmoid())
6664

6765
self.shared2 = nn.Sequential()
68-
self.shared2.add_module('shared_1', nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size))
69-
self.shared2.add_module('shared_activation_1', nn.Sigmoid())
66+
self.shared2.add_module('shared_2', nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size))
67+
self.shared2.add_module('shared_activation_2', nn.Sigmoid())
7068

7169
self.fusion = nn.Sequential()
7270
self.fusion.add_module('fusion_layer_1', nn.Linear(in_features=self.config.hidden_size * 2,

solver.py

Lines changed: 11 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,9 @@
1111
torch.manual_seed(123)
1212
torch.cuda.manual_seed_all(123)
1313

14-
1514
class Solver(object):
1615
def __init__(self, train_config, train_data_loader, dev_data_loader, test_data_loader,
1716
is_train=True, model=None):
18-
1917
self.scale_criterion = corr_loss()
2018
self.polar_criterion = cos_loss()
2119
self.criterion = nn.MSELoss(reduction="mean")
@@ -28,7 +26,7 @@ def __init__(self, train_config, train_data_loader, dev_data_loader, test_data_l
2826

2927
def build(self, cuda=True):
3028
if self.model is None:
31-
self.model = getattr(new_models, self.train_config.model)(self.train_config)
29+
self.model = getattr(new_models, self.train_config.model)(self.train_config) # init the model
3230

3331
# Final list
3432
for name, param in self.model.named_parameters():
@@ -41,11 +39,9 @@ def build(self, cuda=True):
4139
elif self.train_config.data == "ur_funny":
4240
if "bert" in name:
4341
param.requires_grad = False
44-
4542
if 'weight_hh' in name:
4643
nn.init.orthogonal_(param)
4744
# print('\t' + name, param.requires_grad)
48-
4945
if torch.cuda.is_available() and cuda:
5046
self.model.cuda()
5147

@@ -55,11 +51,16 @@ def build(self, cuda=True):
5551
lr=self.train_config.learning_rate)
5652

5753
def model_input2output(self, batch):
54+
"""
55+
get output from model input
56+
:param batch: batch
57+
:return: y_tilde: model predict output
58+
y: true label
59+
"""
5860
self.model.zero_grad()
5961

6062
v, a, y, l, bert_sent, bert_sent_type, bert_sent_mask = batch
6163

62-
6364
v = to_gpu(v)
6465
a = to_gpu(a)
6566
y = to_gpu(y)
@@ -73,6 +74,9 @@ def model_input2output(self, batch):
7374
return y_tilde, y
7475

7576
def loss_function(self, y_tilde, y):
77+
"""
78+
total_loss = w1 * cls_loss + w_2 * polar_loss + w_3 * scale_loss
79+
"""
7680
polar_loss = self.polar_criterion(self.model.polar_vector, y, y_tilde)
7781
scale_loss = self.scale_criterion(self.model.scale, y)
7882
cls_loss = self.criterion(y_tilde, y)
@@ -89,16 +93,14 @@ def train(self):
8993
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=0.5)
9094

9195
train_losses = []
92-
9396
for e in range(self.train_config.n_epoch):
9497
self.model.train()
9598
train_loss = []
9699
for batch in self.train_data_loader:
97100
y_tilde, y = self.model_input2output(batch)
98-
99101
loss = self.loss_function(y_tilde, y)
100-
loss.backward()
101102

103+
loss.backward()
102104
self.optimizer.step()
103105
train_loss.append(loss.item())
104106
train_losses.append(train_loss)
@@ -136,18 +138,7 @@ def train(self):
136138
if num_trials <= 0:
137139
print("Running out of patience, early stopping.")
138140
break
139-
140141
self.eval(mode="test", to_print=True)
141-
# print(f"best test accracy: {best_test_acc}")
142-
#
143-
# # data = {"train_loss": self.train_loss, "train_task": self.task_loss, "train_polar":self.polar_loss, "train_scale":self.scale_loss,
144-
# # "val_loss": self.val_loss, "val_task": self.task_loss_val, "val_polar":self.polar_loss_val, "val_scale":self.scale_loss_val}
145-
# #
146-
# data = {"polar":np.concatenate(self.polar_vector, axis=0).squeeze(),"scale":np.concatenate(self.scale_vector, axis=0).squeeze(),"label":np.concatenate(self.label, axis=0).squeeze(),"label_pred":np.concatenate(self.label_pred, axis=0).squeeze()}
147-
# # def to_pickle(obj, path):
148-
# # with open(path, 'wb') as f:
149-
# # pickle.dump(obj, f)
150-
# # to_pickle(data,"cluster_7.pkl")
151142

152143
def eval(self, mode=None, to_print=False):
153144
assert (mode is not None)
@@ -160,16 +151,13 @@ def eval(self, mode=None, to_print=False):
160151
dataloader = self.dev_data_loader
161152
elif mode == "test":
162153
dataloader = self.test_data_loader
163-
164154
if to_print:
165155
self.model.load_state_dict(torch.load(
166156
f'checkpoints/model_{self.train_config.name}.std'))
167157

168158
with torch.no_grad():
169-
170159
for batch in dataloader:
171160
y_tilde, y = self.model_input2output(batch)
172-
173161
loss = self.loss_function(y_tilde, y)
174162

175163
eval_loss.append(loss.item())
@@ -193,8 +181,6 @@ def multiclass_acc(self, preds, truths):
193181
:param truths: Float/int array representing the groundtruth classes, dimension (N,)
194182
:return: Classification accuracy
195183
"""
196-
# print(Counter(np.round(preds)))
197-
# print(Counter(np.round(truths)))
198184
return np.sum(np.round(preds) == np.round(truths)) / float(len(truths))
199185

200186
def calc_metrics(self, y_true, y_pred, mode=None, to_print=False):
@@ -229,7 +215,6 @@ def calc_metrics(self, y_true, y_pred, mode=None, to_print=False):
229215
print("mae: ", mae)
230216
print("corr: ", corr)
231217
print("mult_acc5: ", mult_a5)
232-
233218
print("Classification Report (pos/neg) :")
234219
print(classification_report(binary_truth, binary_preds, digits=5))
235220
print("Accuracy (pos/neg) ", accuracy_score(binary_truth, binary_preds))
@@ -248,5 +233,4 @@ def calc_metrics(self, y_true, y_pred, mode=None, to_print=False):
248233
print("F1:", f_score2)
249234

250235
print("Accuracy (non-neg/neg) ", accuracy_score(binary_truth, binary_preds))
251-
252236
return accuracy_score(binary_truth, binary_preds)

train.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,9 @@
66
import torch
77
import warnings
88

9-
109
warnings.filterwarnings("ignore") # ignore the warning
1110

1211
if __name__ == '__main__':
13-
1412
# Setting random seed
1513
random_name = str(random())
1614
random_seed = 5546
@@ -32,7 +30,6 @@
3230
dev_data_loader = get_loader(dev_config, shuffle = False)
3331
test_data_loader = get_loader(test_config, shuffle = False)
3432

35-
3633
# Solver is a wrapper for model traiing and testing
3734
# solver = Solver
3835
solver = Solver(train_config, train_data_loader, dev_data_loader, test_data_loader, is_train=True)

0 commit comments

Comments
 (0)