Skip to content

Commit

Permalink
fix sampler bugs and update dataloader
Browse files Browse the repository at this point in the history
  • Loading branch information
bubbliiiing committed Apr 15, 2022
1 parent 5524b5e commit 71c94df
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 12 deletions.
6 changes: 6 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,12 @@
gen.dataset.epoch_now = epoch
gen_val.dataset.epoch_now = epoch

if distributed:
train_sampler.set_epoch(epoch)

set_optimizer_lr(optimizer, lr_scheduler_func, epoch)

fit_one_epoch(model_train, model, yolo_loss, loss_history, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, UnFreeze_Epoch, Cuda, fp16, scaler, save_period, save_dir, local_rank)

if local_rank == 0:
loss_history.writer.close()
6 changes: 4 additions & 2 deletions utils/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import cv2
import numpy as np
import torch
from PIL import Image
from torch.utils.data.dataset import Dataset

Expand Down Expand Up @@ -354,5 +355,6 @@ def yolo_dataset_collate(batch):
for img, box in batch:
images.append(img)
bboxes.append(box)
images = np.array(images)
return images, bboxes
images = torch.from_numpy(np.array(images)).type(torch.FloatTensor)
bboxes = [torch.from_numpy(ann).type(torch.FloatTensor) for ann in bboxes]
return images, bboxes
14 changes: 4 additions & 10 deletions utils/utils_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,8 @@ def fit_one_epoch(model_train, model, yolo_loss, loss_history, optimizer, epoch,
images, targets = batch[0], batch[1]
with torch.no_grad():
if cuda:
images = torch.from_numpy(images).type(torch.FloatTensor).cuda()
targets = [torch.from_numpy(ann).type(torch.FloatTensor).cuda() for ann in targets]
else:
images = torch.from_numpy(images).type(torch.FloatTensor)
targets = [torch.from_numpy(ann).type(torch.FloatTensor) for ann in targets]
images = images.cuda()
targets = [ann.cuda() for ann in targets]
#----------------------#
# 清零梯度
#----------------------#
Expand Down Expand Up @@ -94,11 +91,8 @@ def fit_one_epoch(model_train, model, yolo_loss, loss_history, optimizer, epoch,
images, targets = batch[0], batch[1]
with torch.no_grad():
if cuda:
images = torch.from_numpy(images).type(torch.FloatTensor).cuda()
targets = [torch.from_numpy(ann).type(torch.FloatTensor).cuda() for ann in targets]
else:
images = torch.from_numpy(images).type(torch.FloatTensor)
targets = [torch.from_numpy(ann).type(torch.FloatTensor) for ann in targets]
images = images.cuda()
targets = [ann.cuda() for ann in targets]
#----------------------#
# 清零梯度
#----------------------#
Expand Down

0 comments on commit 71c94df

Please sign in to comment.