-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathdatasets.py
55 lines (42 loc) · 1.79 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
# -*- coding: utf-8 -*-
"""
Created on Thu Mar 12 10:11:09 2020
@author: NAT
"""
import torch
from torch.utils.data import Dataset
import json
import os
from PIL import Image
from utils import transform
class VOCDataset(Dataset):
    """Dataset backed by two JSON index files stored in one folder.

    For a split ``S`` the folder must contain ``S_images.json`` and
    ``S_objects.json``. Entry ``i`` of the images file is handed to
    ``PIL.Image.open``; entry ``i`` of the objects file is a mapping with
    the keys ``"boxes"``, ``"labels"`` and ``"difficulties"``.
    """

    # Split names for which index files are looked up.
    _SPLITS = ("TRAIN", "TEST")

    def __init__(self, DataFolder, split):
        """Load the image and object indexes for one split.

        Args:
            DataFolder: folder where the JSON index files are stored.
            split: split name, one of ``"TRAIN"`` / ``"TEST"``
                (case-insensitive; stored upper-cased in ``self.split``).

        Raises:
            AssertionError: if ``split`` is not a known split, or if the two
                index files do not have the same number of entries. The
                type is kept from the original ``assert``-based checks so
                existing callers are unaffected, but it is raised explicitly
                so the checks still run under ``python -O``.
            OSError: if an index file cannot be opened.
        """
        super().__init__()

        self.split = split.upper()
        if self.split not in self._SPLITS:
            raise AssertionError("Param split not in {TRAIN, TEST}")

        self.DataFolder = DataFolder

        # Read both indexes with an explicit encoding so the result does not
        # depend on the platform's default locale.
        images_path = os.path.join(DataFolder, self.split + '_images.json')
        with open(images_path, 'r', encoding='utf-8') as j:
            self.images = json.load(j)

        objects_path = os.path.join(DataFolder, self.split + '_objects.json')
        with open(objects_path, 'r', encoding='utf-8') as j:
            self.objects = json.load(j)

        # The two indexes are paired by position, so they must line up.
        if len(self.images) != len(self.objects):
            raise AssertionError(
                "%s and %s have different lengths: %d != %d"
                % (images_path, objects_path,
                   len(self.images), len(self.objects))
            )

    def __len__(self):
        """Return the number of entries in the image index."""
        return len(self.images)

    def __getitem__(self, i):
        """Return the transformed sample at index ``i``.

        The image is opened and converted to RGB; the annotation lists are
        wrapped as ``FloatTensor`` (boxes), ``LongTensor`` (labels) and
        ``ByteTensor`` (difficulties). Everything is then passed, together
        with the split name, to ``utils.transform``.

        Returns:
            The 4-tuple ``(image, boxes, labels, difficulties)`` produced by
            ``transform``.
            NOTE(review): the types/shapes of these outputs are decided by
            ``transform``, which is not visible here - confirm there.
        """
        # Convert inside a context manager so the underlying file handle is
        # released deterministically; ``convert`` returns an independent image.
        with Image.open(self.images[i], mode="r") as raw_image:
            image = raw_image.convert("RGB")

        # Read objects in this image
        objects = self.objects[i]
        boxes = torch.FloatTensor(objects["boxes"])
        labels = torch.LongTensor(objects['labels'])
        difficulties = torch.ByteTensor(objects['difficulties'])

        # Apply transforms
        new_image, new_boxes, new_labels, new_difficulties = transform(
            image, boxes, labels, difficulties, self.split
        )

        return new_image, new_boxes, new_labels, new_difficulties