Skip to content

Commit

Permalink
Update train.py
Browse files Browse the repository at this point in the history
  • Loading branch information
layumi authored Jul 2, 2024
1 parent 6bf8a6f commit 9a66c11
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,10 +221,11 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
#best_acc = 0.0
warm_up = 0.1 # We start from the 0.1*lrRate
warm_iteration = round(dataset_sizes['train']/opt.batchsize)*opt.warm_epoch # first 5 epoch
embedding_size = model.classifier.linear.linear_num
if opt.arcface:
criterion_arcface = losses.ArcFaceLoss(num_classes=opt.nclasses, embedding_size=512)
criterion_arcface = losses.ArcFaceLoss(num_classes=opt.nclasses, embedding_size=embedding_size)
if opt.cosface:
criterion_cosface = losses.CosFaceLoss(num_classes=opt.nclasses, embedding_size=512)
criterion_cosface = losses.CosFaceLoss(num_classes=opt.nclasses, embedding_size=embedding_size)
if opt.circle:
criterion_circle = CircleLoss(m=0.25, gamma=32) # gamma = 64 may lead to a better result.
if opt.triplet:
Expand All @@ -237,7 +238,7 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
if opt.instance:
criterion_instance = InstanceLoss(gamma = opt.ins_gamma)
if opt.sphere:
criterion_sphere = losses.SphereFaceLoss(num_classes=opt.nclasses, embedding_size=512, margin=4)
criterion_sphere = losses.SphereFaceLoss(num_classes=opt.nclasses, embedding_size=embedding_size, margin=4)
for epoch in range(num_epochs):
print('Epoch {}/{}'.format(epoch, num_epochs - 1))
# print('-' * 10)
Expand Down

0 comments on commit 9a66c11

Please sign in to comment.