-
Notifications
You must be signed in to change notification settings - Fork 1
/
dataloading.py
82 lines (61 loc) · 2.34 KB
/
dataloading.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
from torch.utils.data import Dataset
from tqdm import tqdm
class SentenceDataset(Dataset):
    """
    Skeleton of a custom PyTorch ``Dataset`` for text samples (sentences).

    A ``Dataset`` subclass has to provide two methods:
        - ``__len__(self)``: tells the DataLoader how many samples exist,
          which it needs for batching, shuffling and so on.
        - ``__getitem__(self, index)``: returns the processed sample stored
          at position ``index``.

    Note:
        This class is an exercise template. ``__init__`` (EX2) and
        ``__getitem__`` (EX3) are placeholders that raise
        ``NotImplementedError``; only ``__len__`` has a body, and it relies
        on a ``self.data`` attribute that the finished constructor is
        expected to set.
    """

    def __init__(self, X, y, word2idx):
        """
        Set up the dataset (EX2 - not implemented yet).

        The finished constructor is meant to:
            - keep every meaningful argument as an instance attribute, both
              for debugging and for use by the other methods;
            - do most of the heavy lifting here, i.e. preprocess the text
              samples once, up front.

        Args:
            X (list): List of training samples
            y (list): List of training labels
            word2idx (dict): a dictionary which maps words to indexes

        Raises:
            NotImplementedError: always, until the exercise is completed.
        """
        # EX2 - suggested starting point for the implementation:
        #     self.data = X
        #     self.labels = y
        #     self.word2idx = word2idx
        raise NotImplementedError()

    def __len__(self):
        """
        Report how many samples the dataset holds, so the dataloader knows
        how to split it into batches.

        Returns:
            (int): the length of ``self.data``
        """
        n_samples = len(self.data)
        return n_samples

    def __getitem__(self, index):
        """
        Return the _transformed_ sample at ``index`` (EX3 - not implemented
        yet).

        Args:
            index (int): position of the requested sample

        Returns:
            (tuple): once implemented, the method is expected to return
                * example (ndarray): vector representation of a training
                  example
                * label (int): the class label
                * length (int): the length (tokens) of the sentence

        Raises:
            NotImplementedError: always, until the exercise is completed.

        Examples:
            For an `index` where:
            ::
                self.data[index] = ['this', 'is', 'really', 'simple']
                self.labels[index] = "neutral"

            the function will have to return something like:
            ::
                example = [  533  3908  1387   649   0     0     0     0]
                label = 1
                length = 4
        """
        # EX3 - the finished method should end with:
        #     return example, label, length
        raise NotImplementedError()