-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathmerdataset.py
93 lines (72 loc) · 2.7 KB
/
merdataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import json
from collections import Counter
from torch.utils.data import Dataset
class MERDataset(Dataset):
    """Emotion-recognition samples loaded from ``processed_KEMDy20.json``.

    The JSON is indexed as ``data[session][script]``; every leaf is a list of
    record dicts.  Each record must carry an ``'Emotion'`` string.  Records
    whose ``'Emotion'`` holds several labels joined by ``';'`` are expanded
    into one independent copy per label.  After construction every record
    also has an integer ``'label'`` taken from ``self.emo_map``.

    Sessions 1-32 form the ``'train'`` split and sessions 33-40 the ``'test'``
    split.
    """

    def __init__(self, data_option='train', path='./'):
        """Load one split of the dataset.

        Args:
            data_option: ``'train'`` (sessions 1-32) or ``'test'``
                (sessions 33-40).
            path: String prefix placed directly in front of
                ``'processed_KEMDy20.json'`` (plain concatenation, so a
                directory must end with a path separator).

        Raises:
            ValueError: If ``data_option`` is neither ``'train'`` nor
                ``'test'``.
            KeyError: If a record's emotion is not a key of ``self.emo_map``,
                or an expected session/script is missing from the JSON.
        """
        session_ids = {'train': range(1, 33), 'test': range(33, 41)}
        if data_option not in session_ids:
            # Previously this fell through and failed later on an unbound name.
            raise ValueError(
                "data_option must be 'train' or 'test', got {!r}".format(data_option)
            )

        with open(path + 'processed_KEMDy20.json', 'r', encoding='utf-8') as file:
            raw = json.load(file)

        # NOTE(review): 'disqust' is kept as spelled because the keys must match
        # the label strings found in the JSON - presumably the dataset's own
        # spelling; confirm against the data before "fixing" it.
        self.emo_map = {'neutral': 0,
                        'happy': 1,
                        'surprise': 2,
                        'angry': 3,
                        'sad': 4,
                        'disqust': 5,
                        'fear': 6}

        # The script names of the first session are used for every session.
        scripts = list(raw['Sess01'])

        records = []
        for sess in session_ids[data_option]:
            session = 'Sess' + '{0:0>2}'.format(sess)
            for script in scripts:
                records.extend(raw[session][script])

        # Split multi-label records into one copy per label.  Each copy is a
        # new dict so the labels do not overwrite one another.  Single-label
        # records keep their position; expanded copies go at the end.
        single_label = []
        expanded = []
        for record in records:
            if ';' in record['Emotion']:
                for emotion in record['Emotion'].split(';'):
                    emotion = emotion.strip()
                    if emotion:  # ignore empty pieces such as a trailing ';'
                        expanded.append(dict(record, Emotion=emotion))
            else:
                single_label.append(record)
        self.data = single_label + expanded

        for record in self.data:
            record['label'] = self.emo_map[record['Emotion']]

    def __len__(self):
        """Return the number of samples after multi-label expansion."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the record dict stored at position ``idx``."""
        return self.data[idx]

    def prepare_text_data(self, text_config):
        """Add a ``'dialogue'`` string to every record, in place.

        The dialogue is the record's ``'utterance'`` followed by the first
        ``K - 1`` entries of its ``'history'``, joined with ``'[SEP]'``.

        Args:
            text_config: Any object exposing an integer attribute ``K``, the
                total number of turns to keep (current utterance included).
                ``K <= 1`` keeps only the current utterance.
        """
        # Clamp so K <= 0 does not turn into a negative slice bound, which
        # would select almost the whole history instead of nothing.
        history_turns = max(text_config.K - 1, 0)
        for record in self.data:
            turns = [record['utterance']] + record['history'][:history_turns]
            record['dialogue'] = '[SEP]'.join(turns)

    def get_weight(self):
        """Return inverse-frequency class weights, indexed by label id.

        Returns:
            A list with one float per entry of ``self.emo_map``:
            ``len(self) / count(label)``.  A label that does not occur in
            this split gets ``0.0`` instead of raising ``ZeroDivisionError``.
        """
        counts = Counter(record['label'] for record in self.data)
        total = len(self.data)
        # Label ids in emo_map are contiguous and start at 0.
        return [
            total / counts[label] if counts[label] else 0.0
            for label in range(len(self.emo_map))
        ]
if __name__ == '__main__':
    # Manual smoke test: draw one pass over the training split through a
    # WeightedRandomSampler and print how often each label was sampled.
    import torch
    from torch.utils.data.dataloader import DataLoader
    from torch.utils.data.sampler import WeightedRandomSampler

    dataset = MERDataset(data_option='train', path='./data/')
    # The per-class weights are computed here but replaced below by
    # per-sample weights, which is what the sampler consumes.
    weight = dataset.get_weight()

    labels = [sample['label'] for sample in dataset]
    n_samples = len(labels)
    inverse_freq = {}
    for label, count in Counter(labels).items():
        inverse_freq[label] = n_samples / count

    # One weight per sample: the inverse frequency of that sample's label.
    weight = [inverse_freq[label] for label in labels]
    sampler = WeightedRandomSampler(weight, len(weight))

    def collate(batch):
        """Return the raw record list together with a tensor of its labels."""
        return batch, torch.Tensor([record['label'] for record in batch])

    dataloader = DataLoader(dataset, batch_size=16, sampler=sampler, collate_fn=collate)

    drawn = []
    for _, batch_labels in dataloader:
        drawn.extend(batch_labels.tolist())
    print(Counter(drawn))