-
Notifications
You must be signed in to change notification settings - Fork 47
/
data_loader.py
91 lines (72 loc) · 2.83 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
#!/usr/bin/env python3
import csv
import json
from collections import Counter

import torch
import torch.autograd as autograd
from torch.utils.data import DataLoader, Dataset
class AGNEWs(Dataset):
    """Character-level AG's News dataset.

    Each item is an ``(X, y)`` pair: ``X`` is a one-hot character matrix of
    shape ``(len(alphabet), l0)`` and ``y`` is the integer label read from
    the first CSV column.
    """

    def __init__(self, label_data_path, alphabet_path, l0=1014):
        """Create AG's News dataset object.

        Arguments:
            label_data_path: The path of label and data file in csv.
            alphabet_path: The path of alphabet json file.
            l0: max length of a sample; only the last ``l0`` characters of
                each text are encoded.
        """
        self.label_data_path = label_data_path
        self.l0 = l0
        self.loadAlphabet(alphabet_path)
        self.load(label_data_path)

    def __len__(self):
        """Return the number of samples (one per CSV row loaded)."""
        return len(self.label)

    def __getitem__(self, idx):
        """Return ``(one-hot matrix, label tensor)`` for sample ``idx``."""
        X = self.oneHotEncode(idx)
        y = self.y[idx]
        return X, y

    def loadAlphabet(self, alphabet_path):
        """Read the alphabet JSON file into ``self.alphabet`` as one string.

        The decoded JSON value is passed through ``''.join``, so either a
        JSON string or a list of strings is accepted.
        """
        # JSON is UTF-8 by spec; don't depend on the platform default encoding.
        with open(alphabet_path, encoding='utf-8') as f:
            self.alphabet = ''.join(json.load(f))

    def load(self, label_data_path, lowercase=True):
        """Read labels and texts from a CSV file.

        Each row is ``label, field1, field2, ...``. The text fields are
        joined with single spaces and lowercased when ``lowercase`` is true.
        Blank rows are skipped.

        Populates ``self.label`` (list of int), ``self.data`` (list of str)
        and ``self.y`` (``torch.LongTensor`` of the labels).
        """
        self.label = []
        self.data = []
        # newline='' is required by the csv module so that quoted fields
        # containing line breaks are parsed correctly.
        with open(label_data_path, 'r', newline='', encoding='utf-8') as f:
            rdr = csv.reader(f, delimiter=',', quotechar='"')
            for row in rdr:
                if not row:
                    # csv.reader yields [] for blank lines; row[0] would fail.
                    continue
                self.label.append(int(row[0]))
                txt = ' '.join(row[1:])
                if lowercase:
                    txt = txt.lower()
                self.data.append(txt)
        self.y = torch.LongTensor(self.label)

    def oneHotEncode(self, idx):
        """One-hot encode sample ``idx`` into a ``(len(alphabet), l0)`` tensor.

        The text is read back to front, so its last character lands in
        column 0. Only the last ``l0`` characters are encoded. Characters
        not found in the alphabet leave their column all zeros.
        """
        X = torch.zeros(len(self.alphabet), self.l0)
        sequence = self.data[idx]
        # Truncate after reversing: without this, texts longer than l0 index
        # past the last column and raise IndexError.
        for index_char, char in enumerate(sequence[::-1][:self.l0]):
            index_alpha = self.char2Index(char)
            if index_alpha != -1:
                X[index_alpha, index_char] = 1.0
        return X

    def char2Index(self, character):
        """Return the index of ``character`` in the alphabet, or -1 if absent."""
        return self.alphabet.find(character)

    def getClassWeight(self):
        """Return inverse-frequency class weights and per-class sample counts.

        Returns:
            ``(class_weight, num_class)``. For each distinct label, in the
            iteration order of ``set(self.label)``, these hold
            ``num_samples / count`` and ``count`` respectively.
        """
        num_samples = self.__len__()
        # One O(n) pass instead of two list.count() scans per class.
        counts = Counter(self.label)
        label_set = set(self.label)
        num_class = [counts[c] for c in label_set]
        class_weight = [num_samples / float(counts[c]) for c in label_set]
        return class_weight, num_class
if __name__ == '__main__':
    # Smoke test: push every sample of the AG News test split through a
    # DataLoader (exercising the one-hot encoder on each one) and print the
    # shape of the first encoded row of the first batch.
    dataset = AGNEWs(label_data_path='data/ag_news_csv/test.csv',
                     alphabet_path='alphabet.json')
    loader = DataLoader(dataset, batch_size=64, num_workers=4, drop_last=False)
    for batch_number, (inputs, _targets) in enumerate(loader):
        if batch_number == 0:
            print(inputs[0][0].shape)