-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathdata_utils.py
More file actions
90 lines (83 loc) · 3.61 KB
/
data_utils.py
File metadata and controls
90 lines (83 loc) · 3.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import pandas as pd
from torch.utils.data import Dataset
import cv2
import torch
from utils import get_transforms,DatasetName
class UltraMNISTDataset(Dataset):
    """Map-style dataset over UltraMNIST images listed in a DataFrame.

    Each row of ``df`` provides an ``image_id`` (the file stem of
    ``<root_dir>/<image_id>.jpeg``) and a ``digit_sum`` (the target).

    Args:
        df: DataFrame with ``image_id`` and ``digit_sum`` columns; rows are
            addressed positionally via ``iloc``.
        root_dir: Directory containing the ``.jpeg`` files.
        transforms: Optional callable applied to the loaded image.
    """

    def __init__(self, df, root_dir, transforms=None):
        self.df = df
        self.root_dir = root_dir
        self.transforms = transforms

    def __len__(self):
        """Return the number of rows in ``df``."""
        return len(self.df)

    def __getitem__(self, index):
        """Return ``(image, torch.tensor(digit_sum))`` for row ``index``.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        row = self.df.iloc[index]  # single positional lookup for both fields
        path = f"{self.root_dir}/{row.image_id}.jpeg"
        image = cv2.imread(path)
        # cv2.imread signals a missing/unreadable file by returning None
        # instead of raising; fail loudly rather than passing None on.
        if image is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        # NOTE(review): unlike PandasDataset there is no BGR->RGB conversion
        # here, so the image keeps OpenCV's BGR channel order. Confirm intended.
        if self.transforms is not None:
            image = self.transforms(image)
        return image, torch.tensor(row.digit_sum)
class PandasDataset(Dataset):
    """Map-style dataset over PANDA images listed in a DataFrame.

    Each row of ``df`` provides an ``image_id`` (the file stem of
    ``<root_dir>/<image_id>.png``) and an ``isup_grade`` (the target).

    Args:
        df: DataFrame with ``image_id`` and ``isup_grade`` columns; rows are
            addressed positionally via ``iloc``.
        root_dir: Directory containing the ``.png`` files.
        transforms: Optional callable applied to the RGB image.
    """

    def __init__(self, df, root_dir, transforms=None):
        self.df = df
        self.root_dir = root_dir
        self.transforms = transforms

    def __len__(self):
        """Return the number of rows in ``df``."""
        return len(self.df)

    def __getitem__(self, index):
        """Return ``(image, torch.tensor(isup_grade))`` for row ``index``.

        The image is converted from OpenCV's BGR order to RGB before the
        optional transforms are applied.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        row = self.df.iloc[index]  # single positional lookup for both fields
        path = f"{self.root_dir}/{row.image_id}.png"
        image = cv2.imread(path)
        # cv2.imread returns None on a missing/unreadable file; without this
        # check the failure surfaces as an opaque cvtColor assertion error.
        if image is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transforms is not None:
            image = self.transforms(image)
        return image, torch.tensor(row.isup_grade)
def _truncate_and_report(train_df, val_df, sanity_check, sanity_data_len, print_lengths):
    """Optionally cut both splits to their first ``sanity_data_len`` rows and print their sizes."""
    if sanity_check:
        train_df = train_df[:sanity_data_len]
        val_df = val_df[:sanity_data_len]
    if print_lengths:
        print(f"Train set length: {len(train_df)}, validation set length: {len(val_df)}")
    return train_df, val_df


def get_train_val_dataset(train_csv_path,
                          val_csv_path,
                          sanity_check,
                          sanity_data_len,
                          train_root_dir,
                          val_root_dir,
                          image_size,
                          mean,
                          std,
                          print_lengths=True,
                          dataset_name=DatasetName.PANDAS,
                          seed=None):
    """Build the training and validation datasets for the selected dataset.

    For ``DatasetName.PANDAS`` a single CSV (``train_csv_path``) is split by
    its ``kfold`` column: fold 0 is validation, all other folds are training.
    ``val_csv_path`` is not read in this case.

    For ``DatasetName.UMNIST`` the training and validation CSVs are read
    separately and each is shuffled.

    Both splits use the same transforms from ``get_transforms``.

    Args:
        train_csv_path: CSV with the training rows (or all rows for PANDAS).
        val_csv_path: CSV with the validation rows (UMNIST only).
        sanity_check: If true, keep only the first ``sanity_data_len`` rows
            of each split.
        sanity_data_len: Row limit used when ``sanity_check`` is set.
        train_root_dir: Image directory for the training split.
        val_root_dir: Image directory for the validation split.
        image_size, mean, std: Forwarded to ``get_transforms``.
        print_lengths: Print the final split sizes.
        dataset_name: ``DatasetName.PANDAS`` or ``DatasetName.UMNIST``.
        seed: ``random_state`` for the UMNIST shuffles. The default ``None``
            keeps the shuffle non-deterministic; pass an int to reproduce it.

    Returns:
        ``(train_dataset, validation_dataset)``.

    Raises:
        ValueError: if ``dataset_name`` is not a supported dataset.
    """
    # Explicit raise (not assert) so the check survives `python -O`.
    if dataset_name not in (DatasetName.UMNIST, DatasetName.PANDAS):
        raise ValueError('Not a valid dataset name')
    transforms_dataset = get_transforms(image_size=image_size, mean=mean, std=std)

    if dataset_name == DatasetName.PANDAS:
        df = pd.read_csv(train_csv_path)
        train_df = df[df['kfold'] != 0]
        val_df = df[df['kfold'] == 0]
        dataset_cls = PandasDataset
    else:  # DatasetName.UMNIST (validated above)
        train_df = pd.read_csv(train_csv_path).sample(frac=1, random_state=seed).reset_index(drop=True)
        val_df = pd.read_csv(val_csv_path).sample(frac=1, random_state=seed).reset_index(drop=True)
        dataset_cls = UltraMNISTDataset

    train_df, val_df = _truncate_and_report(
        train_df, val_df, sanity_check, sanity_data_len, print_lengths
    )
    train_dataset = dataset_cls(train_df, train_root_dir, transforms_dataset)
    validation_dataset = dataset_cls(val_df, val_root_dir, transforms_dataset)
    return train_dataset, validation_dataset
class PatchDataset(Dataset):
    """Dataset over a ``num_patches x num_patches`` grid of square patches.

    Item ``choice`` is the slice
    ``images[:, :, stride*i : stride*i + patch_size, stride*j : stride*j + patch_size]``
    with ``i = choice % num_patches`` and ``j = choice // num_patches``. The
    index therefore advances along the third indexed dimension first, then
    along the fourth.

    Args:
        images: Array/tensor indexed as ``images[:, :, rows, cols]``;
            presumably ``(batch, channels, height, width)``. TODO confirm.
        num_patches: Patches per side; the dataset has ``num_patches ** 2`` items.
        stride: Offset between neighbouring patch origins along each sliced
            dimension.
        patch_size: Side length of each patch.
    """

    def __init__(self, images, num_patches, stride, patch_size):
        self.images = images
        self.num_patches = num_patches
        self.stride = stride
        self.patch_size = patch_size

    def __len__(self):
        """Return the number of patches, ``num_patches ** 2``."""
        return self.num_patches ** 2

    def __getitem__(self, choice):
        """Return ``(patch, index)`` for grid position ``choice``.

        Negative values count from the end, as for Python sequences. The
        returned index is the normalized, non-negative position.

        Raises:
            IndexError: if ``choice`` is outside ``[-len(self), len(self))``.
                Besides rejecting bad indices, this is what stops the legacy
                ``__getitem__`` iteration protocol (``for p in dataset``,
                ``list(dataset)``). Out-of-range slices would otherwise just
                yield empty tensors and that iteration would never end.
        """
        total = len(self)
        index = choice + total if choice < 0 else choice
        if not 0 <= index < total:
            raise IndexError(f"patch index {choice} out of range for {total} patches")
        i = index % self.num_patches
        j = index // self.num_patches
        rows = slice(self.stride * i, self.stride * i + self.patch_size)
        cols = slice(self.stride * j, self.stride * j + self.patch_size)
        return self.images[:, :, rows, cols], index