-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataloader.py
64 lines (51 loc) · 1.86 KB
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# coding: utf-8
import os
import numpy as np
import torch
import torch.utils.data
class MPIIGazeDataset(torch.utils.data.Dataset):
    """In-memory dataset backed by a single ``<subject_id>.npz`` archive.

    The archive is read once at construction time. Its ``image``, ``pose``
    and ``gaze`` arrays are converted to tensors and kept on the instance;
    the image tensor gets an extra axis inserted at position 1.
    """

    def __init__(self, subject_id, dataset_dir):
        """Load the arrays for one subject.

        Args:
            subject_id: Stem of the archive; ``<subject_id>.npz`` is read.
            dataset_dir: Directory that contains the archive.
        """
        archive_path = os.path.join(dataset_dir, '{}.npz'.format(subject_id))
        with np.load(archive_path) as archive:
            images = archive['image']
            poses = archive['pose']
            gazes = archive['gaze']

        # The sample count is the length of the image array's first axis.
        self.length = len(images)
        # Insert a new axis at position 1 of the image tensor.
        self.images = torch.from_numpy(images).unsqueeze(1)
        self.poses = torch.from_numpy(poses)
        self.gazes = torch.from_numpy(gazes)

    def __getitem__(self, index):
        """Return the ``(image, pose, gaze)`` triple stored at ``index``."""
        image = self.images[index]
        pose = self.poses[index]
        gaze = self.gazes[index]
        return image, pose, gaze

    def __len__(self):
        """Return the number of samples loaded from the archive."""
        return self.length

    def __repr__(self):
        """Return the bare class name."""
        return type(self).__name__
def get_loader(dataset_dir, test_subject_id, batch_size, num_workers, use_gpu):
    """Build leave-one-subject-out train and test data loaders.

    Subjects are named ``p00`` .. ``p14``. The subject selected by
    ``test_subject_id`` becomes the test set; the remaining fourteen are
    concatenated into the training set.

    Args:
        dataset_dir: Directory holding one ``<subject>.npz`` file per subject.
        test_subject_id: Index in ``range(15)`` of the held-out subject.
        batch_size: Batch size used by both loaders.
        num_workers: Worker process count used by both loaders.
        use_gpu: Forwarded to both loaders as ``pin_memory``.

    Returns:
        A ``(train_loader, test_loader)`` tuple. The training loader shuffles
        and drops the last incomplete batch; the test loader does neither.

    Raises:
        FileNotFoundError: If ``dataset_dir`` does not exist.
        ValueError: If ``test_subject_id`` is outside ``range(15)``, or the
            loaded datasets do not have the expected 42000 / 3000 samples.
    """
    num_subjects = 15
    expected_train_size = 42000
    expected_test_size = 3000

    # Explicit raises instead of ``assert``: assertions are removed under
    # ``python -O``, and a negative index would then silently pick a subject
    # from the end of the list.
    if not os.path.exists(dataset_dir):
        raise FileNotFoundError(
            'dataset_dir does not exist: {!r}'.format(dataset_dir))
    if test_subject_id not in range(num_subjects):
        raise ValueError(
            'test_subject_id must be in range({}), got {!r}'.format(
                num_subjects, test_subject_id))

    subject_ids = ['p{:02}'.format(index) for index in range(num_subjects)]
    test_subject_id = subject_ids[test_subject_id]

    train_dataset = torch.utils.data.ConcatDataset([
        MPIIGazeDataset(subject_id, dataset_dir)
        for subject_id in subject_ids
        if subject_id != test_subject_id
    ])
    test_dataset = MPIIGazeDataset(test_subject_id, dataset_dir)

    # Sanity-check the data on disk before handing it to the loaders.
    if len(train_dataset) != expected_train_size:
        raise ValueError(
            'expected {} training samples, found {}'.format(
                expected_train_size, len(train_dataset)))
    if len(test_dataset) != expected_test_size:
        raise ValueError(
            'expected {} test samples, found {}'.format(
                expected_test_size, len(test_dataset)))

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=use_gpu,
        drop_last=True,
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
        pin_memory=use_gpu,
        drop_last=False,
    )
    return train_loader, test_loader