-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcxr_dataset.py
85 lines (71 loc) · 2.74 KB
/
cxr_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import pandas as pd
import numpy as np
from torch.utils.data import Dataset
import os
from PIL import Image
###########################################################
# EXTERNAL CLASS
###########################################################
class CXRDataset(Dataset):
    """Dataset of labelled images described by ``nih_labels.csv``.

    The label table is read from ``nih_labels.csv`` in the current working
    directory and restricted to the rows whose ``fold`` column equals
    ``fold``. Each item is a tuple ``(image, label, image_name)`` where
    ``label`` is an integer vector with one entry per name in
    ``self.PRED_LABEL``.

    Args:
        path_to_images: Directory that contains the image files named in the
            ``Image Index`` column.
        fold: Value of the ``fold`` column to keep.
        transform: Optional callable applied to the RGB image before it is
            returned.
        sample: If greater than 0 and smaller than the number of available
            rows, keep only a random sample of this many rows.
        finding: ``"any"`` to keep all rows, otherwise the name of a column;
            only rows where that column equals 1 are kept. If the column is
            missing or has no positive rows, a message is printed and the
            rows are left unfiltered.
        starter_images: If truthy, keep only rows whose ``Image Index`` also
            appears in ``starter_images.csv`` (inner merge).
    """

    def __init__(
            self,
            path_to_images,
            fold,
            transform=None,
            sample=0,
            finding="any",
            starter_images=False):
        self.transform = transform
        self.path_to_images = path_to_images

        self.df = pd.read_csv("nih_labels.csv")
        self.df = self.df[self.df['fold'] == fold]

        if starter_images:
            # Keep the flag and the loaded table in separate names.
            starter_df = pd.read_csv("starter_images.csv")
            self.df = pd.merge(
                left=self.df,
                right=starter_df,
                how="inner",
                on="Image Index")

        # Can limit to a sample, useful for testing.
        if 0 < sample < len(self.df):
            self.df = self.df.sample(sample)

        # Can filter for positive findings of the kind described; useful for
        # evaluation.
        if finding != "any":
            if finding in self.df.columns:
                positives = self.df[self.df[finding] == 1]
                if len(positives) > 0:
                    self.df = positives
                else:
                    # The original referenced an undefined name here and
                    # crashed with NameError instead of falling back.
                    print("No positive cases exist for " + finding +
                          ", returning all unfiltered cases")
            else:
                print("cannot filter on finding " + finding +
                      " as not in data - please check spelling")

        self.df = self.df.set_index("Image Index")
        self.PRED_LABEL = [
            'Atelectasis',
            'Cardiomegaly',
            'Effusion',
            'Infiltration',
            'Mass',
            'Nodule',
            'Pneumonia',
            'Pneumothorax',
            'Consolidation',
            'Edema',
            'Emphysema',
            'Fibrosis',
            'Pleural_Thickening',
            'Hernia']

    def __len__(self):
        """Return the number of rows (images) kept after filtering."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, label, image_name)`` for the row at position ``idx``.

        The image is opened from ``path_to_images`` and converted to RGB;
        ``transform`` is applied to it when set. ``label`` holds, for each
        name in ``self.PRED_LABEL``, the integer value of that column when it
        is positive and 0 otherwise.
        """
        image_name = self.df.index[idx]

        # convert() loads the pixel data and returns a new image, so the
        # underlying file can be closed as soon as the conversion is done.
        with Image.open(os.path.join(self.path_to_images, image_name)) as raw:
            image = raw.convert('RGB')

        label = np.zeros(len(self.PRED_LABEL), dtype=int)
        for i, name in enumerate(self.PRED_LABEL):
            value = self.df[name].iloc[idx].astype('int')
            # Can leave zero if zero, else copy the positive value.
            if value > 0:
                label[i] = value

        if self.transform:
            image = self.transform(image)

        return (image, label, image_name)