-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathdata.py
55 lines (47 loc) · 1.59 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
# -*- coding: utf-8 -*-
# @Author : Magic
# @Time : 2019/7/4 12:01
# @File : data.py
from config import config_dict
from dataset import SenseData, SenseDataTest
from torchvision import transforms
# Dataset root directories for each split, looked up in the project
# configuration once, when this module is imported.
img_train, img_val, img_test = (
    config_dict['data_dir_train'],
    config_dict['data_dir_val'],
    config_dict['data_dir_test'],
)
def train_augs():
    """Build the training-time augmentation pipeline.

    Steps, in order: random resized crop to ``config_dict['im_size']``,
    random horizontal flip, mild colour jitter (brightness, contrast and
    saturation, each 0.1), conversion to a tensor, and per-channel
    normalisation.

    Returns:
        transforms.Compose: the composed transform.
    """
    # Per-channel mean / std for Normalize (the usual ImageNet values).
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    steps = [
        transforms.RandomResizedCrop(config_dict['im_size']),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]
    return transforms.Compose(steps)
def val_augs(resize_size=256):
    """Build the deterministic validation preprocessing pipeline.

    Steps, in order:

    1. Resize so the shorter image side equals ``resize_size``.
    2. Center-crop to ``config_dict['im_size']``.
    3. Convert to a tensor.
    4. Normalise with the same per-channel mean/std as ``train_augs``.

    Args:
        resize_size: target length of the shorter side before cropping.
            Defaults to 256, the value previously hard-coded here.

    Returns:
        transforms.Compose: the composed transform.
    """
    return transforms.Compose([
        # Bilinear is Resize's default interpolation in every torchvision
        # release, so the former opaque ``interpolation=2`` (PIL's BILINEAR
        # code) is dropped. Several torchvision releases warn on integer codes.
        # NOTE(review): if config_dict['im_size'] > resize_size the crop is
        # larger than the resized image. Confirm im_size <= resize_size.
        transforms.Resize(resize_size),
        transforms.CenterCrop(config_dict['im_size']),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.485, 0.456, 0.406],
            [0.229, 0.224, 0.225],
        ),
    ])
def test_augs(resize_size=256):
    """Build the deterministic test-time preprocessing pipeline.

    This is currently the same sequence as ``val_augs``:

    1. Resize so the shorter image side equals ``resize_size``.
    2. Center-crop to ``config_dict['im_size']``.
    3. Convert to a tensor.
    4. Normalise with the same per-channel mean/std as ``train_augs``.

    Args:
        resize_size: target length of the shorter side before cropping.
            Defaults to 256, the value previously hard-coded here.

    Returns:
        transforms.Compose: the composed transform.
    """
    return transforms.Compose([
        # Bilinear is Resize's default interpolation in every torchvision
        # release, so the former opaque ``interpolation=2`` (PIL's BILINEAR
        # code) is dropped. Several torchvision releases warn on integer codes.
        # NOTE(review): if config_dict['im_size'] > resize_size the crop is
        # larger than the resized image. Confirm im_size <= resize_size.
        transforms.Resize(resize_size),
        transforms.CenterCrop(config_dict['im_size']),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.485, 0.456, 0.406],
            [0.229, 0.224, 0.225],
        ),
    ])
# Private sentinel meaning "caller did not pass a transform".
#
# The defaults used to be ``train_augs()`` etc., evaluated once at import
# time. Every default call then shared one Compose object, whose mutable
# ``.transforms`` list one caller could alter for all later callers.
#
# With a sentinel, a fresh pipeline is built per call, while an explicit
# ``transform=None`` is still forwarded to the dataset unchanged.
_DEFAULT_TRANSFORM = object()


def get_train_data(img_path=img_train, transform=_DEFAULT_TRANSFORM):
    """Return the training dataset.

    Args:
        img_path: dataset root. Defaults to ``config_dict['data_dir_train']``.
        transform: transform passed to the dataset. When omitted, a fresh
            ``train_augs()`` pipeline is built for this call.

    Returns:
        SenseData: dataset constructed from ``img_path`` and ``transform``.
    """
    if transform is _DEFAULT_TRANSFORM:
        transform = train_augs()
    return SenseData(img_path, transform)


def get_val_data(img_path=img_val, transform=_DEFAULT_TRANSFORM):
    """Return the validation dataset.

    Args:
        img_path: dataset root. Defaults to ``config_dict['data_dir_val']``.
        transform: transform passed to the dataset. When omitted, a fresh
            ``val_augs()`` pipeline is built for this call.

    Returns:
        SenseData: dataset constructed from ``img_path`` and ``transform``.
    """
    if transform is _DEFAULT_TRANSFORM:
        transform = val_augs()
    return SenseData(img_path, transform)


def get_test_data(img_path=img_test, transform=_DEFAULT_TRANSFORM):
    """Return the test dataset.

    Args:
        img_path: dataset root. Defaults to ``config_dict['data_dir_test']``.
        transform: transform passed to the dataset. When omitted, a fresh
            ``test_augs()`` pipeline is built for this call.

    Returns:
        SenseDataTest: dataset constructed from ``img_path`` and ``transform``.
    """
    if transform is _DEFAULT_TRANSFORM:
        transform = test_augs()
    return SenseDataTest(img_path, transform)