-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathCNN.py
122 lines (83 loc) · 3.06 KB
/
CNN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import pandas as pd
import torch
from glob import glob
import cv2
from PIL import Image
from numpy import asarray
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torch.nn as nn
import torch.nn.functional as F
def from_path_to_feature(path):
    """Load the image stored at ``path`` and return its pixels as a flat array.

    Args:
        path: Location of the image file, in any form ``PIL.Image.open``
            accepts.

    Returns:
        A one-dimensional NumPy array containing every value of the decoded
        image; ``reshape(-1)`` discards the original row/column/channel
        layout.
    """
    # Use a context manager so the file handle is released deterministically
    # rather than whenever the image object happens to be garbage collected.
    with Image.open(path) as image:
        # Converting inside the ``with`` block guarantees the pixel data has
        # been read before the file is closed.
        data = asarray(image)
    return data.reshape(-1)
class ResizedDataset(Dataset):
    """Map-style dataset pairing each resized product image with a class label.

    Instance attributes set by ``__init__``:
        df: DataFrame with one row per image and the columns ``path``,
            ``id`` and ``category`` (built by ``get_data``).
        label_dict: Mapping from category name to its integer class index.
    """

    def __init__(self):
        super().__init__()
        self.df = self.get_data()
        # The full mapping is built once, up front.  Assigning labels lazily
        # inside ``__getitem__`` made the indices depend on the (shuffled)
        # access order and left the mapping incomplete until every category
        # had been seen.
        self.label_dict = self._build_label_dict(self.df)

    @staticmethod
    def _build_label_dict(df):
        """Return a deterministic ``{category: index}`` mapping for ``df``.

        Categories are sorted so the mapping is identical on every run, and
        indices start at 0 because ``F.cross_entropy`` expects class indices
        in ``[0, number_of_classes)``.
        """
        categories = sorted(df['category'].unique())
        return {category: index for index, category in enumerate(categories)}

    def get_data(self):
        """Assemble the image table from the product/image CSVs and the image folder.

        Products (rows with missing values dropped, ``id`` renamed to
        ``product_id``) are merged with the images table on ``product_id``.
        Every ``*.jpg`` found in the resized-images folder is then matched to
        that table through the file name without its extension, and only the
        part of the category before the first ``/`` is kept.

        Returns:
            DataFrame with the columns ``path``, ``id`` and ``category``.
        """
        dframe = pd.read_csv("/Users/macbook/Downloads/Products.csv", index_col=0).dropna().rename(columns={'id': 'product_id'})
        images_df = pd.read_csv("/Users/macbook/Downloads/Images.csv", index_col=0)
        merged_df = dframe.merge(images_df, on='product_id')
        image_data_df = pd.DataFrame({'path': glob('/Users/macbook/Downloads/resized_images/*.jpg')})
        # The image id is the file name with the ".jpg" suffix removed.
        image_data_df['id'] = image_data_df.path.str.split('/').str[-1].str.replace('.jpg', '', regex=False)
        image_data_df = image_data_df.merge(merged_df[['category', 'id']], on='id')
        # Keep only the top-level category (text before the first "/").
        image_data_df['category'] = image_data_df.category.str.split('/').str[0]
        return image_data_df

    def __getitem__(self, index):
        """Return ``(feature_tensor, label)`` for the image at position ``index``.

        The feature is whatever ``from_path_to_feature`` returns for the
        row's path, wrapped in a tensor; the label is the integer class
        index of the row's category.
        """
        row = self.df.iloc[index]
        feature = from_path_to_feature(row['path'])
        label = self.label_dict[row['category']]
        # NOTE(review): ``from_path_to_feature`` flattens the image, so this
        # tensor is 1-D and keeps the image's integer dtype; a Conv2d-based
        # model needs a float (channels, height, width) tensor -- confirm the
        # intended input format before training.
        return (torch.tensor(feature), label)

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.df)
# Module-level dataset and loader.  ``train`` reads ``imagedsloader`` as a
# global, so both are created as soon as this module is executed.
torchdset = ResizedDataset()
imagedsloader = DataLoader(dataset=torchdset, batch_size=10, shuffle=True)
# Disabled debugging snippet, kept as an inert string literal.
'''
for batch in imagedsloader:
print(batch)
break
print(torchdset.label_dict)
'''
# Pull a single batch and keep its two parts on the dataset object for
# interactive inspection.
example = next(iter(imagedsloader))
torchdset.feature = example[0]
torchdset.label = example[1]
def train(model, epochs=10):
    """Train ``model`` with SGD on the batches of the module-level ``imagedsloader``.

    For every batch the cross-entropy loss is computed, printed, used for
    one optimiser step and logged to TensorBoard under the tag ``'Loss'``.

    Args:
        model: Module called as ``model(feature)`` whose output is passed to
            ``F.cross_entropy`` together with the batch labels.
        epochs: Number of complete passes over ``imagedsloader``.
    """
    optimiser = torch.optim.SGD(model.parameters(), lr=0.001)
    writer = SummaryWriter()
    # Global step used as the x-axis of the TensorBoard loss curve.
    batch_idx = 0
    try:
        for _ in range(epochs):
            for feature, label in imagedsloader:
                prediction = model(feature)
                loss = F.cross_entropy(prediction, label)
                loss.backward()
                loss_value = loss.item()
                print(loss_value)
                optimiser.step()
                # Clear the gradients so they do not accumulate into the
                # next batch's backward pass.
                optimiser.zero_grad()
                writer.add_scalar('Loss', loss_value, batch_idx)
                # Advance the step counter; incrementing the batch itself
                # (as before) raised a TypeError after the first step.
                batch_idx += 1
    finally:
        # Flush pending events and release the log file even if training
        # is interrupted by an exception.
        writer.close()
class CNN(torch.nn.Module):
    """Small convolutional classifier producing 9 class probabilities.

    The network is two unpadded 5x5 convolutions (3 -> 20 -> 10 channels)
    followed by two fully connected layers (250 -> 150 -> 9) and a softmax.
    """

    def __init__(self):
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Conv2d(3, 20, 5),
            torch.nn.ReLU(),
            torch.nn.Conv2d(20, 10, 5),
            torch.nn.Flatten(),
            # in_features = 10 channels * 5 * 5: each unpadded 5x5
            # convolution trims 4 pixels per side length, so this layer only
            # matches 13x13 input images.
            torch.nn.Linear(10 * 5 * 5, 150),
            torch.nn.ReLU(),
            torch.nn.Linear(150, 9),
            # Normalise over the class dimension explicitly; omitting ``dim``
            # relies on a deprecated implicit choice and emits a warning on
            # every forward pass.
            # NOTE(review): F.cross_entropy applies log-softmax itself, so a
            # training loop using it would normally want raw logits here --
            # confirm whether this layer should stay.
            torch.nn.Softmax(dim=1),
        )

    def forward(self, X):
        """Run a batch ``X`` of shape (batch, 3, height, width) through the layers.

        Returns:
            Tensor of shape (batch, 9) whose rows sum to 1.
        """
        return self.layers(X)
if __name__ == '__main__':
    # Script entry point: rebuild the dataset and loader that ``train``
    # reads as globals, then train a freshly initialised CNN with the
    # default number of epochs.
    torchdset = ResizedDataset()
    imagedsloader = DataLoader(dataset=torchdset, batch_size=10, shuffle=True)
    model = CNN()
    train(model=model)