-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathresnet_model.py
116 lines (78 loc) · 3.47 KB
/
resnet_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import pandas as pd
import torch
import time
from glob import glob
from PIL import Image
from numpy import asarray
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
def from_path_to_feature(path):
    """Load the image at ``path`` and return its pixel values as a flat array.

    ``Image.open`` is lazy and keeps the underlying file open until the image
    object is released. The pixels are therefore copied out inside a ``with``
    block, so the file handle is closed deterministically.

    Args:
        path: Filesystem path of an image readable by PIL.

    Returns:
        1-D numpy array containing every pixel value, flattened in C
        (row-major) order.
    """
    with Image.open(path) as image:
        data = asarray(image)
    return data.reshape(-1)
# Dataset of resized product images labelled by their top-level category
class ResizedDataset(Dataset):
    """Map-style dataset of resized product images labelled by top-level category.

    Products, image metadata and image files are joined in :meth:`get_data`.
    Each item is ``(flattened image tensor, integer label)``.

    Args:
        products_csv: Products CSV (first column is the index). Must contain
            ``id`` and ``category`` columns; rows with any missing value are dropped.
        images_csv: Image-metadata CSV (first column is the index). Must
            contain ``id`` (image id) and ``product_id`` columns.
        image_dir: Directory holding ``<image id>.jpg`` files.

    Attributes:
        df: Joined frame with ``path``, ``id`` and ``category`` columns.
        label_dict: Mapping from category name to a 0-based class index.
    """

    def __init__(self,
                 products_csv="/Users/macbook/Downloads/Products.csv",
                 images_csv="/Users/macbook/Downloads/Images.csv",
                 image_dir="/Users/macbook/Downloads/resized_images"):
        super().__init__()
        self.products_csv = products_csv
        self.images_csv = images_csv
        self.image_dir = image_dir
        self.df = self.get_data()
        # Build the category -> label mapping once, up front, from the sorted
        # category names. This makes labels:
        # - 0-based, because F.cross_entropy expects targets in [0, num_classes);
        # - deterministic, independent of access order, shuffling, or
        #   DataLoader worker processes (each of which holds its own copy).
        self.label_dict = {
            category: label
            for label, category in enumerate(sorted(self.df.category.unique()))
        }

    def get_data(self):
        """Join products, image metadata and image files into one frame.

        Returns:
            DataFrame with columns ``path``, ``id`` and ``category``: one row
            per ``.jpg`` file whose id matches a product image. ``category``
            keeps only the text before the first ``/``.
        """
        products = (
            pd.read_csv(self.products_csv, index_col=0)
            .dropna()
            .rename(columns={'id': 'product_id'})
        )
        images = pd.read_csv(self.images_csv, index_col=0)
        merged = products.merge(images, on='product_id')
        # dtype=object keeps the ``.str`` accessor usable even when no files match.
        image_df = pd.DataFrame(
            {'path': glob(f'{self.image_dir}/*.jpg')}, dtype=object)
        image_df['id'] = (
            image_df.path.str.split('/').str[-1].str.replace('.jpg', '', regex=False)
        )
        image_df = image_df.merge(merged[['category', 'id']], on='id')
        image_df['category'] = image_df.category.str.split('/').str[0]
        return image_df

    def __getitem__(self, index):
        row = self.df.iloc[index]
        feature = from_path_to_feature(row.path)
        return (torch.tensor(feature), self.label_dict[row.category])

    def __len__(self):
        return len(self.df)
# NOTE(review): these two statements run at import time, so merely importing
# this module reads the CSV files and globs the image directory. The
# ``__main__`` block below rebinds both names to fresh instances.
torchdset = ResizedDataset()
imagedsloader = DataLoader(
    torchdset,
    shuffle=True,
    batch_size=10,
)
# Image classifier built on a pre-trained ResNet-50 backbone
class ImageClassifier(torch.nn.Module):
    """Pre-trained ResNet-50 (NVIDIA torch hub) with a new classification head.

    Args:
        num_classes: Output width of the replacement ``fc`` layer.
        in_channels: Number of input image channels.
            - 3 (default): the pre-trained stem is kept unchanged.
            - Any other value: the first convolution is replaced by one with
              the same geometry, whose per-input-channel kernels are copied
              cyclically from the pre-trained kernels (see ``_adapt_first_conv``).
    """

    def __init__(self, num_classes=9, in_channels=3):
        super().__init__()
        # torch.hub fetches the model definition and weights (network or hub cache).
        self.resnet50 = torch.hub.load(
            'NVIDIA/DeepLearningExamples:torchhub', 'nvidia_resnet50', pretrained=True)
        # 2048 is the width of ResNet-50's pooled features feeding the head.
        self.resnet50.fc = torch.nn.Linear(2048, num_classes)
        if in_channels != 3:
            # NOTE(review): assumes the hub model exposes its stem convolution
            # as ``conv1`` — confirm against the hub model definition.
            self.resnet50.conv1 = self._adapt_first_conv(self.resnet50.conv1, in_channels)

    @staticmethod
    def _adapt_first_conv(conv, in_channels):
        """Return a copy of ``conv`` that accepts ``in_channels`` input channels.

        Output channels, kernel size, stride, padding, dilation and bias
        presence are preserved. Input channel ``i`` of the new weight is
        initialised from channel ``i % conv.in_channels`` of the original
        weight. For example, with 4 inputs the extra channel reuses channel 0.
        """
        new_conv = nn.Conv2d(
            in_channels,
            conv.out_channels,
            kernel_size=conv.kernel_size,
            stride=conv.stride,
            padding=conv.padding,
            dilation=conv.dilation,
            bias=conv.bias is not None,
        )
        with torch.no_grad():
            source = conv.weight
            for channel in range(in_channels):
                new_conv.weight[:, channel] = source[:, channel % source.shape[1]]
            if conv.bias is not None:
                new_conv.bias.copy_(conv.bias)
        return new_conv

    def forward(self, X):
        return self.resnet50(X)
#Training the model
def train(model, dataloader, epochs=10, lr=0.001):
    """Train ``model`` with plain SGD on a cross-entropy loss.

    Args:
        model: Module mapping a float feature batch to class logits.
        dataloader: Iterable of ``(features, labels)`` batches. ``labels``
            must hold class indices in ``[0, num_classes)`` as required by
            ``F.cross_entropy``.
        epochs: Number of full passes over ``dataloader``.
        lr: SGD learning rate.

    Feature handling:
        - Features with fewer than 4 dimensions get an extra axis inserted at
          position 1.
        - 4-D ``(batch, channels, H, W)`` batches are passed through unchanged.
        - Features are cast to float32, since raw pixel tensors are integer
          typed and layer weights are float.

    The loss of every batch is printed.
    """
    model.train()
    optimiser = torch.optim.SGD(model.parameters(), lr=lr)
    for epoch in range(epochs):
        for feature, label in dataloader:
            if feature.dim() < 4:
                # NOTE(review): a flat (batch, pixels) feature becomes
                # (batch, 1, pixels), which is not a valid image batch for a
                # ResNet stem; the dataset likely needs to yield (C, H, W).
                feature = feature.unsqueeze(1)
            feature = feature.float()
            optimiser.zero_grad()
            prediction = model(feature)
            loss = F.cross_entropy(prediction, label)
            loss.backward()
            optimiser.step()
            print(f'epoch {epoch + 1}/{epochs} loss {loss.item():.4f}')
# Script entry point: build the data pipeline and model, train, then save the weights
if __name__ == '__main__':
    import os

    torchdset = ResizedDataset()
    imagedsloader = DataLoader(torchdset, batch_size=10, shuffle=True)
    model = ImageClassifier()
    train(model, imagedsloader)

    # Call time.time(): interpolating the bare function object would produce
    # the same "<built-in function time>" filename on every run. Both
    # artefacts share one timestamp so they can be paired up later.
    stamp = str(int(time.time()))
    save_dir = '/Users/macbook/Documents/GitHub/facebook-marketplaces-recommendation-ranking-system/model_evaluation'
    weights_dir = os.path.join(save_dir, 'weights')
    os.makedirs(weights_dir, exist_ok=True)

    sd = model.state_dict()
    torch.save(sd, os.path.join(save_dir, stamp))
    # The backbone lives in the ``resnet50`` attribute, so the head's
    # state-dict key carries that prefix.
    torch.save(sd['resnet50.fc.weight'], os.path.join(weights_dir, stamp))