converted to PyTorch Lightning #2

Open · wants to merge 1 commit into base: master
60 changes: 33 additions & 27 deletions Code/main.py
@@ -1,11 +1,16 @@
import os
import torch
import time
from model import FGSBIR_Model
from dataset import get_dataloader
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
import argparse


# PyTorch Lightning Modules
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.profilers import SimpleProfiler


if __name__ == "__main__":
@@ -28,29 +33,30 @@
dataloader_Train, dataloader_Test = get_dataloader(hp)
print(hp)


model = FGSBIR_Model(hp)
model.to(device)
# model.load_state_dict(torch.load('VGG_ShoeV2_model_best.pth', map_location=device))
step_count, top1, top10 = -1, 0, 0

for i_epoch in range(hp.max_epoch):
for batch_data in dataloader_Train:
step_count = step_count + 1
start = time.time()
model.train()
loss = model.train_model(batch=batch_data)

if step_count % hp.print_freq_iter == 0:
print('Epoch: {}, Iteration: {}, Loss: {:.5f}, Top1_Accuracy: {:.5f}, Top10_Accuracy: {:.5f}, Time: {}'.format
(i_epoch, step_count, loss, top1, top10, time.time()-start))

if step_count % hp.eval_freq_iter == 0:
with torch.no_grad():
top1_eval, top10_eval = model.evaluate(dataloader_Test)
print('results : ', top1_eval, ' / ', top10_eval)

if top1_eval > top1:
torch.save(model.state_dict(), hp.backbone_name + '_' + hp.dataset_name + '_model_best.pth')
top1, top10 = top1_eval, top10_eval
print('Model Updated')
exp_name = '%s-%s'%(hp.backbone_name, hp.dataset_name)
logger = WandbLogger(project="Baseline FGSBIR", name=exp_name)

checkpoint_callback = ModelCheckpoint(
monitor='top1', mode='max', dirpath=exp_name,
filename=hp.backbone_name, save_last=True)

model = FGSBIR_Model(hp)
# ModelCheckpoint(save_last=True) writes last.ckpt directly under dirpath
# (exp_name), and trainer.fit(ckpt_path=...) below restores weights,
# optimizer state and epoch counter, so no separate load_from_checkpoint
# call is needed.
ckpt_path = os.path.join(exp_name, 'last.ckpt')
if not os.path.exists(ckpt_path):
    ckpt_path = None

profiler = SimpleProfiler(
dirpath=os.path.join(exp_name, hp.dataset_name),
filename='perf-logs')

trainer = Trainer(logger=logger,
accelerator='gpu', devices=1, accumulate_grad_batches=None,
benchmark=False, deterministic=False, detect_anomaly=False,
callbacks=[checkpoint_callback], check_val_every_n_epoch=1,
log_every_n_steps=10, overfit_batches=0.0, limit_val_batches=1.0,
max_epochs=hp.max_epoch, enable_model_summary=True, profiler=profiler)

trainer.fit(model, dataloader_Train, dataloader_Test, ckpt_path=ckpt_path)
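
The resume logic above hinges on two PyTorch Lightning behaviours: ModelCheckpoint(save_last=True) writes last.ckpt under its dirpath, and trainer.fit(ckpt_path=...) restores model weights, optimizer state, and the epoch counter from that file. A minimal sketch of the pattern, separate from this PR (fit_with_resume, the max_epochs value, and the ckpt_dir layout are illustrative assumptions):

import os
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

def fit_with_resume(model, train_loader, val_loader, ckpt_dir):
    # Keep the best checkpoint by the logged 'top1' metric, plus last.ckpt.
    checkpoint = ModelCheckpoint(monitor='top1', mode='max',
                                 dirpath=ckpt_dir, save_last=True)
    last = os.path.join(ckpt_dir, 'last.ckpt')  # written by save_last=True
    trainer = Trainer(callbacks=[checkpoint], max_epochs=200)
    # ckpt_path=None starts fresh; an existing last.ckpt resumes training.
    trainer.fit(model, train_loader, val_loader,
                ckpt_path=last if os.path.exists(last) else None)
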
65 changes: 34 additions & 31 deletions Code/model.py
@@ -5,50 +5,52 @@
import torch
import time
import torch.nn.functional as F
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

import pytorch_lightning as pl

class FGSBIR_Model(nn.Module):

class FGSBIR_Model(pl.LightningModule):
def __init__(self, hp):
super(FGSBIR_Model, self).__init__()
self.sample_embedding_network = eval(hp.backbone_name + '_Network(hp)')
self.loss = nn.TripletMarginLoss(margin=0.2)
self.sample_train_params = self.sample_embedding_network.parameters()
self.optimizer = optim.Adam(self.sample_train_params, hp.learning_rate)
self.hp = hp

def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.hp.learning_rate)
return optimizer

def train_model(self, batch):
self.train()
self.optimizer.zero_grad()
def training_step(self, batch, batch_idx):

positive_feature = self.sample_embedding_network(batch['positive_img'].to(device))
negative_feature = self.sample_embedding_network(batch['negative_img'].to(device))
sample_feature = self.sample_embedding_network(batch['sketch_img'].to(device))
positive_feature = self.sample_embedding_network(batch['positive_img'])
negative_feature = self.sample_embedding_network(batch['negative_img'])
sample_feature = self.sample_embedding_network(batch['sketch_img'])

loss = self.loss(sample_feature, positive_feature, negative_feature)
loss.backward()
self.optimizer.step()

return loss.item()

def evaluate(self, datloader_Test):
self.log('train_loss', loss)
return loss

def validation_step(self, batch, batch_idx):
sketch_feat, positive_feat = self.test_forward(batch)
return sketch_feat, positive_feat, batch['sketch_path'], batch['positive_path']

def validation_epoch_end(self, validation_step_outputs):
Image_Feature_ALL = []
Image_Name = []
Sketch_Feature_ALL = []
Sketch_Name = []
start_time = time.time()
self.eval()
for i_batch, sanpled_batch in enumerate(datloader_Test):
sketch_feature, positive_feature= self.test_forward(sanpled_batch)
Sketch_Feature_ALL.extend(sketch_feature)
Sketch_Name.extend(sanpled_batch['sketch_path'])

for i_num, positive_name in enumerate(sanpled_batch['positive_path']):
if positive_name not in Image_Name:
Image_Name.append(sanpled_batch['positive_path'][i_num])
Image_Feature_ALL.append(positive_feature[i_num])

for sketch_feat, positive_feat, sketch_path, positive_path in validation_step_outputs:
Sketch_Feature_ALL.extend(sketch_feat)
Sketch_Name.extend(sketch_path)

for i_num, positive_name in enumerate(positive_path):
if positive_name not in Image_Name:
Image_Name.append(positive_path[i_num])
Image_Feature_ALL.append(positive_feat[i_num])

rank = torch.zeros(len(Sketch_Name))
Image_Feature_ALL = torch.stack(Image_Feature_ALL)

@@ -66,13 +68,14 @@ def evaluate(self, datloader_Test):
top1 = rank.le(1).sum().numpy() / rank.shape[0]
top10 = rank.le(10).sum().numpy() / rank.shape[0]

print('Time to EValuate:{}'.format(time.time() - start_time))
return top1, top10
self.log('top1', top1)
self.log('top10', top10)
print ('Evaluation metrics: Top1 %.4f Top10 %.4f'%(top1, top10))

def test_forward(self, batch):  # called only during validation / evaluation
sketch_feature = self.sample_embedding_network(batch['sketch_img'].to(device))
positive_feature = self.sample_embedding_network(batch['positive_img'].to(device))
return sketch_feature.cpu(), positive_feature.cpu()
sketch_feature = self.sample_embedding_network(batch['sketch_img'])
positive_feature = self.sample_embedding_network(batch['positive_img'])
return sketch_feature, positive_feature
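
A portability note on the hooks used above: validation_epoch_end(self, outputs) exists in pytorch_lightning 1.x but was removed in 2.0. If this code is ever moved to 2.x, the same aggregation can be expressed by collecting outputs by hand in validation_step and processing them in on_validation_epoch_end. A hypothetical adaptation (FGSBIR_Model_v2 and _val_outputs are illustrative names, not part of this PR):

import pytorch_lightning as pl

class FGSBIR_Model_v2(pl.LightningModule):
    # __init__, training_step and test_forward stay as in the diff above;
    # only the validation hooks change.
    def __init__(self, hp):
        super().__init__()
        self.hp = hp
        self._val_outputs = []  # replaces the old hook's outputs argument

    def validation_step(self, batch, batch_idx):
        sketch_feat, positive_feat = self.test_forward(batch)
        self._val_outputs.append(
            (sketch_feat, positive_feat,
             batch['sketch_path'], batch['positive_path']))

    def on_validation_epoch_end(self):
        outputs, self._val_outputs = self._val_outputs, []
        # ...same ranking / top1 / top10 computation as validation_epoch_end above...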