data_loader.py
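
"""Data loader for speaker-conditioned utterance training.

Each entry of train.pkl is assumed to hold a speaker id, a speaker
embedding, and one or more paths to mel-spectrogram .npy files; this
summary is inferred from load_data() below, not from upstream docs.
"""
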
from torch.utils import data
import torch
import numpy as np
import pickle
import os
from multiprocessing import Process, Manager


class Utterances(data.Dataset):
    """Dataset class for the Utterances dataset."""

    def __init__(self, root_dir, len_crop):
        """Initialize and preprocess the Utterances dataset."""
        self.root_dir = root_dir
        self.len_crop = len_crop
        self.step = 10

        metaname = os.path.join(self.root_dir, "train.pkl")
        with open(metaname, "rb") as f:
            meta = pickle.load(f)

        # Load data using multiprocessing, self.step entries per process.
        manager = Manager()
        meta = manager.list(meta)
        dataset = manager.list(len(meta) * [None])  # shared across processes
        processes = []
        for i in range(0, len(meta), self.step):
            p = Process(target=self.load_data,
                        args=(meta[i:i + self.step], dataset, i))
            p.start()
            processes.append(p)
        for p in processes:
            p.join()

        self.train_dataset = list(dataset)
        self.num_tokens = len(self.train_dataset)

        print('Finished loading the dataset...')

    def load_data(self, submeta, dataset, idx_offset):
        for k, sbmt in enumerate(submeta):
            uttrs = len(sbmt) * [None]
            for j, tmp in enumerate(sbmt):
                if j < 2:  # fill in speaker id and embedding
                    uttrs[j] = tmp
                else:  # load the mel-spectrograms
                    uttrs[j] = np.load(os.path.join(self.root_dir, tmp))
            dataset[idx_offset + k] = uttrs
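
    # One metadata entry, as this loader assumes it (values hypothetical):
    #   ['p225', speaker_embedding_array, 'p225/p225_001.npy', ...]
    # Index 0 is the speaker id, index 1 the embedding, index 2+ are
    # per-utterance mel-spectrogram paths loaded above.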

    def __getitem__(self, index):
        # dataset[index] holds one speaker's [id, embedding, mels...]
        dataset = self.train_dataset
        list_uttrs = dataset[index]
        emb_org = list_uttrs[1]

        # pick a random utterance of this speaker, with a random crop
        a = np.random.randint(2, len(list_uttrs))
        tmp = list_uttrs[a]
        if tmp.shape[0] < self.len_crop:
            # too short: zero-pad along time up to len_crop
            len_pad = self.len_crop - tmp.shape[0]
            uttr = np.pad(tmp, ((0, len_pad), (0, 0)), 'constant')
        elif tmp.shape[0] > self.len_crop:
            # too long: take a random len_crop-frame window
            left = np.random.randint(tmp.shape[0] - self.len_crop)
            uttr = tmp[left:left + self.len_crop, :]
        else:
            uttr = tmp

        return uttr, emb_org

    def __len__(self):
        """Return the number of speakers."""
        return self.num_tokens


def get_loader(root_dir, batch_size=16, len_crop=128, num_workers=0):
    """Build and return a data loader."""
    dataset = Utterances(root_dir, len_crop)

    # re-seed numpy in each worker so random picks/crops differ per worker
    worker_init_fn = lambda x: np.random.seed(torch.initial_seed() % (2**32))
    data_loader = data.DataLoader(dataset=dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=num_workers,
                                  drop_last=True,
                                  worker_init_fn=worker_init_fn)
    return data_loader
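

# Minimal usage sketch. Assumption: './spmel' is a hypothetical directory
# containing train.pkl and the mel .npy files it references.
if __name__ == '__main__':
    loader = get_loader('./spmel', batch_size=2, len_crop=128, num_workers=0)
    uttr, emb_org = next(iter(loader))
    # uttr: (batch, len_crop, n_mels) tensor of cropped/padded mels
    # emb_org: (batch, emb_dim) tensor of speaker embeddings
    print(uttr.shape, emb_org.shape)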