-
Notifications
You must be signed in to change notification settings - Fork 12
/
dataloader.py
79 lines (65 loc) · 2.32 KB
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# -*- coding: utf-8 -*- #
"""*********************************************************************************************"""
# FileName [ dataloader.py ]
# Synopsis [ Dataset wrapper and data loader ]
# Author [ Ting-Wei Liu (Andi611) ]
# Copyright [ Copyleft(c), NTUEE, NTU, Taiwan ]
"""*********************************************************************************************"""
###############
# IMPORTATION #
###############
import os
import json
import h5py
import torch
import numpy as np
from torch.utils import data
from collections import namedtuple
class DataLoader(object):
    """Endless, sequential mini-batch iterator over an indexable dataset.

    Each item of ``dataset`` is expected to be a tuple of fields. A batch is
    built by grouping each field across consecutive items, converting the
    group with ``np.array`` and wrapping it with ``torch.from_numpy``; the
    batch is returned as a tuple with one tensor per field.

    Iteration never raises ``StopIteration``: whenever the next full batch
    would run past the end of the dataset, the read position wraps to 0.
    """

    def __init__(self, dataset, batch_size=16):
        """Store the dataset and batch size and start reading at index 0.

        Args:
            dataset: Indexable object supporting ``len()`` whose items are
                tuples of array-like fields.
            batch_size: Number of consecutive items per batch.
        """
        self.dataset = dataset
        # Number of fields per item, probed from the first item.
        self.n_elements = len(self.dataset[0])
        self.batch_size = batch_size
        self.index = 0

    def _collate(self, count):
        """Group ``count`` items starting at ``self.index`` into a tuple of tensors."""
        samples = [self.dataset[self.index + offset] for offset in range(count)]
        # zip(*samples) transposes item-major samples into one column per field.
        return tuple(
            torch.from_numpy(np.array(list(column))) for column in zip(*samples)
        )

    def _advance(self):
        """Step to the next batch, wrapping to 0 if that batch would not fit."""
        # The next batch covers [index + bs, index + 2*bs); it fits iff
        # index + 2*bs <= len.  Using '>=' here would skip the final full batch.
        if self.index + 2 * self.batch_size > len(self.dataset):
            self.index = 0
        else:
            self.index += self.batch_size

    def all(self, size=1000):
        """Return up to ``size`` consecutive items from the current position.

        ``size`` is clamped to the number of items left in the dataset, so a
        large ``size`` returns everything from the current position onward
        instead of raising ``IndexError``. The read position then advances by
        ``batch_size`` (not ``size``), exactly as ``__next__`` does.

        Args:
            size: Maximum number of items to include.

        Returns:
            Tuple of tensors, one per item field.
        """
        count = min(size, len(self.dataset) - self.index)
        batch = self._collate(count)
        self._advance()
        return batch

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next ``batch_size`` items as a tuple of tensors.

        Raises ``IndexError`` if the dataset holds fewer than ``batch_size``
        items.
        """
        batch = self._collate(self.batch_size)
        self._advance()
        return batch
class Dataset(data.Dataset):
    """Fixed-length segment dataset backed by an HDF5 file and a JSON index.

    Each entry of the JSON index is a mapping with exactly the keys
    ``speaker``, ``i`` and ``t``. Item ``k`` is read by slicing ``seg_len``
    leading-axis entries starting at ``t`` from the HDF5 dataset
    ``{dset}/{i}/lin`` (and ``{dset}/{i}/mel`` when ``load_mel`` is set).
    """

    def __init__(self, h5_path, index_path, dset='train', seg_len=64, load_mel=False):
        """Open the HDF5 file and load the segment index.

        Args:
            h5_path: Path to the HDF5 file, opened read-only. This class never
                closes it explicitly.
            index_path: Path to a JSON file holding a list of objects with the
                keys ``speaker``, ``i`` and ``t``.
            dset: Top-level HDF5 group name used as the key prefix.
            seg_len: Number of leading-axis entries per returned segment.
            load_mel: If true, items also include the ``mel`` slice.
        """
        self.dataset = h5py.File(h5_path, 'r')
        # JSON text is UTF-8 by spec; do not depend on the platform locale.
        with open(index_path, encoding='utf-8') as f_index:
            self.indexes = json.load(f_index)
        # Named view over each index entry; missing/extra keys raise TypeError.
        self.indexer = namedtuple('index', ['speaker', 'i', 't'])
        self.seg_len = seg_len
        self.dset = dset
        self.load_mel = load_mel

    def __getitem__(self, i):
        """Return ``(speaker, lin)`` or ``(speaker, lin, mel)`` for entry ``i``.

        ``speaker`` is returned exactly as stored in the JSON index; ``lin``
        and ``mel`` are the slices ``[t:t + seg_len]`` of the corresponding
        HDF5 datasets.
        """
        entry = self.indexer(**self.indexes[i])
        group = f'{self.dset}/{entry.i}'
        start, stop = entry.t, entry.t + self.seg_len
        lin = self.dataset[f'{group}/lin'][start:stop]
        if self.load_mel:
            mel = self.dataset[f'{group}/mel'][start:stop]
            return (entry.speaker, lin, mel)
        return (entry.speaker, lin)

    def __len__(self):
        """Return the number of entries in the segment index."""
        return len(self.indexes)