-
Notifications
You must be signed in to change notification settings - Fork 0
/
making_data.py
95 lines (79 loc) · 2.82 KB
/
making_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# -*- coding: utf-8 -*-
"""making_data.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/18YNVjWRkpHzQD-p85uwxPqHivUdmjocy
"""
import os
from glob import glob
from PIL import Image
import numpy as np
import pickle
from tqdm import tqdm
# Number of pixel classes that get_class() can return (labels 0-4); also the
# channel count of the one-hot masks built further down.
num_classes = 5
# Side length, in pixels, that every image and mask is resized to.
img_size = 512

# Mask PNGs for the two splits, looked up relative to the working directory.
train_mask_paths = glob('Gleason_masks_train/*.png')
valid_mask_paths = glob('Gleason_masks_test/*.png')

# Sanity checks printed before the long-running conversion starts.
print(os.path.exists('../harvard_data/TMA_Images'))
# os.path.basename strips the directory with the platform's own separator;
# splitting on '\\' only worked on Windows and left the directory name in the
# string on POSIX, so the wrong token was printed there.
# Raises IndexError if no training masks were found.
print(os.path.basename(train_mask_paths[0]).split('.')[0].split('_')[1])
def get_class(rgb):
    '''
    takes in rgb values of the pixel and returns the class of the pixel

    rgb is any indexable sequence (tuple, list, numpy array) whose first
    three items are the red, green and blue values on a 0-255 scale.

    Each channel is reduced to "bright" (value/255 > 0.8) or "not bright",
    and the resulting combination selects the class:
        green -> 0, blue -> 1, yellow -> 2, red -> 3, white -> 4

    Raises ValueError for the combinations that match no class
    (no bright channel, red+blue, or green+blue).
    '''
    threshold = 0.8
    # A channel sitting exactly on the threshold (204/255 == 0.8) counts as
    # "not bright". Testing "> 0.8" and "< 0.8" separately left that value
    # matching neither side, so such pixels were rejected as invalid.
    bright = tuple(rgb[channel] / 255.0 > threshold for channel in range(3))

    # Keys are (red bright, green bright, blue bright).
    class_by_colour = {
        (True, True, True): 4,    # white
        (True, False, False): 3,  # red
        (True, True, False): 2,   # yellow
        (False, True, False): 0,  # green
        (False, False, True): 1,  # blue
    }
    if bright not in class_by_colour:
        raise ValueError('Weird rgb combination! Did not match any of 5 classes.')
    return class_by_colour[bright]
# Accumulators: one [image, one-hot mask] pair and one file name per sample.
# The valid_* lists stay empty while the validation pass is disabled.
train_data = []
valid_data = []
train_files = []
valid_files = []

print('Training masks')
for mask_path in tqdm(train_mask_paths):
    # os.path.basename handles the platform's separator; splitting on '\\'
    # only worked on Windows and produced a wrong image path elsewhere.
    # [5:] drops the first five characters of the stem -- presumably a
    # 'mask_' prefix; confirm against the mask file naming.
    file_name = os.path.basename(mask_path).split('.')[0][5:]

    img_array = np.asarray(
        Image.open(f'../harvard_data/TMA_Images/{file_name}.jpg')
        .resize((img_size, img_size))
        .convert('RGB')
    )
    # Masks hold discrete label colours, so resize them with nearest-neighbour.
    # An interpolating filter blends colours along class boundaries, which
    # either mislabels those pixels or makes get_class() raise.
    mask_array = np.asarray(
        Image.open(mask_path)
        .resize((img_size, img_size), Image.NEAREST)
        .convert('RGB')
    )

    # One-hot encode: img_mask[c, row, col] is 1 where the pixel has class c.
    img_mask = np.zeros((num_classes, img_size, img_size))
    for row in range(img_size):
        for col in range(img_size):
            pixel_class = get_class(mask_array[row, col, :])
            img_mask[pixel_class, row, col] = 1

    train_data.append([img_array, img_mask])
    train_files.append(file_name)
# NOTE: the validation pass below is disabled. As written it never loads a
# validation image, so `img_array` would still hold the last training image;
# load the matching image inside the loop before re-enabling it.
# i = 0
# print(f'Validation masks')
# for i in tqdm(range(len(valid_mask_paths))):
# mask_array = np.asarray(Image.open(valid_mask_paths[i]).resize((img_size,img_size)).convert('RGB'))
# img_mask = np.zeros((num_classes,img_size,img_size))
# file_name = valid_mask_paths[i].split('\\')[-1].split('.')[0]
# for x in range(img_size):
# for y in range(img_size):
# pixel_class = get_class(mask_array[x,y,:])
# img_mask[pixel_class,x,y]=1
# valid_data.append([img_array,img_mask])
# valid_files.append(file_name)
print(len(valid_data))

# Each sample is an [image, mask] pair whose two arrays have different shapes,
# so the bundle has to be an object array of shape (n_samples, 2). Building it
# explicitly avoids relying on NumPy's implicit ragged-array creation, which
# newer NumPy releases reject with a ValueError.
bundle = np.empty((len(train_data), 2), dtype=object)
for index, (image, mask) in enumerate(train_data):
    bundle[index, 0] = image
    bundle[index, 1] = mask
train_data = bundle
# valid_data = np.asarray(valid_data)

# np.save appends '.npy'. Object arrays are stored via pickle, so reading the
# file back needs np.load(..., allow_pickle=True).
np.save('correct_train_bundle', train_data)
# np.save('valid_bundle',valid_data)

# The context manager closes the file even if pickling fails.
with open('correct_train_names', 'wb') as tfile:
    pickle.dump(train_files, tfile)
# vfile = open('valid_names','wb')
# pickle.dump(valid_files,vfile)
# vfile.close()