Skip to content

Commit

Permalink
batch renorm
Browse files Browse the repository at this point in the history
  • Loading branch information
SuperbTUM committed Mar 31, 2023
1 parent 93c67a3 commit 9f7edfa
Show file tree
Hide file tree
Showing 16 changed files with 160 additions and 26 deletions.
2 changes: 1 addition & 1 deletion modification_deepsort/feature_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch
import torchvision.transforms as transforms

from reid.SERes18_IBN import SEDense18_IBN
from reid.backbones.SERes18_IBN import SEDense18_IBN


# from .model import Net
Expand Down
Empty file added reid/__init__.py
Empty file.
File renamed without changes.
42 changes: 34 additions & 8 deletions reid/SERes18_IBN.py → reid/backbones/SERes18_IBN.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from torchvision import models
from torch.nn import functional as F

from batchrenorm import BatchRenormalization2D


# This can be applied as channel attention for gallery based on query
class SEBlock(nn.Module):
Expand Down Expand Up @@ -45,16 +47,19 @@ def __repr__(self):


class IBN(nn.Module):
def __init__(self, in_channels, ratio=0.5):
def __init__(self, in_channels, ratio=0.5, renorm=False):
"""
Some do instance norm, some do batch norm
Half do instance norm, half do batch norm
"""
super().__init__()
self.in_channels = in_channels
self.ratio = ratio
self.half = int(self.in_channels * ratio)
self.IN = nn.InstanceNorm2d(self.half, affine=True)
self.BN = nn.BatchNorm2d(self.in_channels - self.half)
if renorm:
self.BN = BatchRenormalization2D(self.in_channels - self.half)
else:
self.BN = nn.BatchNorm2d(self.in_channels - self.half)

def forward(self, x):
split = torch.split(x, self.half, 1)
Expand All @@ -69,22 +74,29 @@ class SEDense18_IBN(nn.Module):
Additionally, we would like to test the network with local average pooling
i.e. Divide into eight and concatenate them
"""
def __init__(self, num_class=751, needs_norm=True, gem=True, is_reid=False, PAP=False):
def __init__(self,
resnet18_pretrained=True,
num_class=751,
needs_norm=True,
gem=True,
renorm=False,
is_reid=False,
PAP=False):
super().__init__()
model = models.resnet18(pretrained=True)
model = models.resnet18(pretrained=resnet18_pretrained)
self.conv0 = model.conv1
self.bn0 = model.bn1
self.relu0 = model.relu
self.pooling0 = model.maxpool

model.layer1[0].bn1 = IBN(64)
model.layer1[0].bn1 = IBN(64, renorm=renorm)
self.basicBlock11 = model.layer1[0]
self.seblock1 = SEBlock(64)

self.basicBlock12 = model.layer1[1]
self.seblock2 = SEBlock(64)

model.layer2[0].bn1 = IBN(128)
model.layer2[0].bn1 = IBN(128, renorm=renorm)
self.basicBlock21 = model.layer2[0]
self.seblock3 = SEBlock(128)
self.ancillaryconv3 = nn.Conv2d(64, 128, 1, 2, 0)
Expand All @@ -93,7 +105,7 @@ def __init__(self, num_class=751, needs_norm=True, gem=True, is_reid=False, PAP=
self.basicBlock22 = model.layer2[1]
self.seblock4 = SEBlock(128)

model.layer3[0].bn1 = IBN(256)
model.layer3[0].bn1 = IBN(256, renorm=renorm)
self.basicBlock31 = model.layer3[0]
self.seblock5 = SEBlock(256)
self.ancillaryconv5 = nn.Conv2d(128, 256, 1, 2, 0)
Expand Down Expand Up @@ -203,3 +215,17 @@ def forward(self, x):
x = self.classifier(x)

return x, feature


def seres18_ibn(num_classes=751, pretrained=False, loss="triplet", **kwargs):
    """Factory for the SE-ResNet18-IBN re-id backbone.

    Args:
        num_classes: number of identity classes for the classifier head.
        pretrained: load ImageNet weights into the underlying resnet18.
        loss: "triplet" (feature/embedding mode) or "softmax" (pure re-id
            classification mode); anything else raises NotImplementedError.
        **kwargs: forwarded verbatim to SEDense18_IBN (e.g. renorm, gem).

    Returns:
        A SEDense18_IBN instance configured for the requested loss.
    """
    # Map the loss name onto the backbone's is_reid switch.
    reid_by_loss = {"triplet": False, "softmax": True}
    if loss not in reid_by_loss:
        raise NotImplementedError
    return SEDense18_IBN(num_class=num_classes,
                         resnet18_pretrained=pretrained,
                         is_reid=reid_by_loss[loss],
                         **kwargs)
9 changes: 9 additions & 0 deletions reid/backbones/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Public submodules of reid.backbones; `from reid.backbones import *`
# re-exports exactly these backbone modules.
__all__ = ["swin_transformer",
           "vision_transformer",
           "SERes18_IBN",
           "plr_osnet",
           "osnet",
           "video_model",
           "AGW_MODEL",
           "model_dense_new"]
# Package version of the backbones collection.
__version__ = "0.1.0"
File renamed without changes.
56 changes: 56 additions & 0 deletions reid/backbones/batchrenorm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# credit to @mf1024
import torch
import torch.nn as nn


class BatchRenormalization2D(nn.Module):
    """Batch Renormalization (Ioffe, 2017) for NCHW feature maps.

    Behaves like BatchNorm2d but corrects the train/eval statistics gap via
    per-batch affine factors r (scale) and d (shift) that are clamped to
    [1/r_max, r_max] and [-d_max, d_max]; the clamps relax from (1, 0) as
    training progresses.

    Args:
        num_features: channel count C of the (N, C, H, W) input.
        eps: lower clamp applied to the batch std to avoid division by zero.
        momentum: EMA factor for the running mean/std updates.
        r_d_max_inc_step: per-sample increment used to relax r_max / d_max.

    Fixes vs. the original (credit @mf1024):
      * running_avg_mean/std were initialized swapped (mean=ones, std=zeros),
        so the first forward divided by zero — now mean=0, std=1, matching
        BatchNorm semantics.
      * running stats, momentum, r_max and d_max are registered as buffers so
        they follow .to(device)/.cuda() and are saved in state_dict.
    """

    def __init__(self, num_features, eps=1e-05, momentum=0.01, r_d_max_inc_step=0.0001):
        super(BatchRenormalization2D, self).__init__()

        self.eps = eps

        # Learnable affine parameters, broadcast over (N, C, H, W).
        self.gamma = torch.nn.Parameter(torch.ones((1, num_features, 1, 1)), requires_grad=True)
        self.beta = torch.nn.Parameter(torch.zeros((1, num_features, 1, 1)), requires_grad=True)

        # Buffers: move with the module and persist in checkpoints.
        self.register_buffer("momentum", torch.tensor(momentum))
        self.register_buffer("running_avg_mean", torch.zeros((1, num_features, 1, 1)))
        self.register_buffer("running_avg_std", torch.ones((1, num_features, 1, 1)))
        self.register_buffer("r_max", torch.tensor(1.0))
        self.register_buffer("d_max", torch.tensor(0.0))

        # Clamp schedule: r_max grows 1 -> 3, d_max grows 0 -> 5.
        self.max_r_max = 3.0
        self.max_d_max = 5.0
        self.r_max_inc_step = r_d_max_inc_step
        self.d_max_inc_step = r_d_max_inc_step

    def forward(self, x):
        # Per-channel statistics over batch and spatial dims.
        batch_ch_mean = torch.mean(x, dim=(0, 2, 3), keepdim=True)
        batch_ch_std = torch.clamp(torch.std(x, dim=(0, 2, 3), keepdim=True), self.eps, 1e10)

        if self.training:
            # r and d are treated as constants w.r.t. autograd (.data),
            # per the batch-renorm paper.
            r = torch.clamp(batch_ch_std / self.running_avg_std,
                            1.0 / self.r_max, self.r_max).data
            d = torch.clamp((batch_ch_mean - self.running_avg_mean) / self.running_avg_std,
                            -self.d_max, self.d_max).data

            x = ((x - batch_ch_mean) * r) / batch_ch_std + d
            x = self.gamma * x + self.beta

            # Relax the clamps proportionally to samples seen this step.
            if self.r_max < self.max_r_max:
                self.r_max += self.r_max_inc_step * x.shape[0]
            if self.d_max < self.max_d_max:
                self.d_max += self.d_max_inc_step * x.shape[0]

            # EMA update of running statistics (in-place, gradient-free).
            self.running_avg_mean += self.momentum * (batch_ch_mean.data - self.running_avg_mean)
            self.running_avg_std += self.momentum * (batch_ch_std.data - self.running_avg_std)
        else:
            # Inference: normalize with the tracked running statistics.
            x = (x - self.running_avg_mean) / self.running_avg_std
            x = self.gamma * x + self.beta

        return x
File renamed without changes.
File renamed without changes.
5 changes: 1 addition & 4 deletions reid/plr_osnet.py → reid/backbones/plr_osnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,9 @@
import torch
from torch import nn
from torch.nn import functional as F
import torchvision
from osnet import *
import copy
import random
import math
from attention_module import Attention_Module
from reid.backbones.attention_module import Attention_Module


class PLR_OSNet(nn.Module):
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
64 changes: 56 additions & 8 deletions reid/image_reid_train.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from train_prepare import *
import os
import torch.onnx
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm import tqdm
from torch.autograd import Variable

from plr_osnet import plr_osnet
from vision_transformer import vit_t
from swin_transformer import swin_t
from reid.backbones.plr_osnet import plr_osnet
from reid.backbones.SERes18_IBN import seres18_ibn
from reid.backbones.vision_transformer import vit_t
from reid.backbones.swin_transformer import swin_t
from train_utils import *
from dataset_market import Market1501

Expand Down Expand Up @@ -61,6 +60,48 @@ def __getitem__(self, item):
return detailed_info


def train_cnn(model, dataset, batch_size=8, epochs=25, num_classes=517, accelerate=False):
    """Train a CNN re-id backbone on `dataset` with the hybrid loss.

    Optionally resumes from `params.ckpt`, optionally wraps everything with
    the `accelerate` helpers, then runs a standard epoch/batch loop, saves a
    checkpoint and exports ONNX.

    Args:
        model: backbone wrapped in nn.DataParallel (ONNX export uses .module).
        dataset: training Dataset yielding (image, label, ...) samples.
        batch_size: per-step batch size.
        epochs: number of passes over the dataset.
        num_classes: identity count forwarded to HybridLoss3.
        accelerate: if True, route training through accelerate_train().

    Returns:
        (model, loss_stats) where loss_stats is the per-step loss history.
    """
    if params.ckpt and os.path.exists(params.ckpt):
        model.eval()
        model_state_dict = torch.load(params.ckpt)
        # strict=False: checkpoint may come from a slightly different head.
        model.load_state_dict(model_state_dict, strict=False)
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
    # step_size=300 counts optimizer steps (scheduler stepped per batch below).
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=300, gamma=0.5)
    loss_func = HybridLoss3(num_classes=num_classes)
    dataloader = DataLoaderX(dataset, batch_size=batch_size, num_workers=4, shuffle=True, pin_memory=True)
    if accelerate:
        res_dict = accelerate_train(model, dataloader, optimizer, lr_scheduler)
        model, dataloader, optimizer, lr_scheduler = res_dict["accelerated"]
        accelerator = res_dict["accelerator"]
    loss_stats = []
    for epoch in range(epochs):
        iterator = tqdm(dataloader)
        for sample in iterator:
            images, label = sample[:2]
            optimizer.zero_grad()
            images = images.cuda(non_blocking=True)
            label = Variable(label).cuda(non_blocking=True)
            embeddings, outputs = model(images)
            loss = loss_func(embeddings, outputs, label)
            loss_stats.append(loss.cpu().item())
            if accelerate:
                accelerator.backward(loss)
            else:
                loss.backward()
            # BUGFIX: clip AFTER backward — before it, .grad is empty and the
            # clip was a silent no-op.
            nn.utils.clip_grad_norm_(model.parameters(), 10)
            optimizer.step()
            lr_scheduler.step()
            description = "epoch: {}, lr: {}, loss: {:.4f}".format(epoch, lr_scheduler.get_last_lr()[0], loss)
            iterator.set_description(description)
    model.eval()
    torch.save(model.state_dict(), "cnn_net_checkpoint.pt")
    to_onnx(model.module,
            torch.randn(batch_size, 3, 256, 128, requires_grad=True, device="cuda"),
            output_names=["embeddings", "outputs"])
    return model, loss_stats


def train_plr_osnet(model, dataset, batch_size=8, epochs=25, num_classes=517, accelerate=False):
if params.ckpt and os.path.exists(params.ckpt):
model.eval()
Expand Down Expand Up @@ -228,7 +269,11 @@ def parser():
args.add_argument("--ckpt", help="where the checkpoint of vit is, can either be a onnx or pt", type=str,
default="vision_transformer_checkpoint.pt")
args.add_argument("--bs", type=int, default=64)
args.add_argument("--backbone", type=str, default="plr_osnet", choices=["plr_osnet", "vit", "swin_v1", "swin_v2"])
args.add_argument("--backbone", type=str, default="plr_osnet", choices=["seres18",
"plr_osnet",
"vit",
"swin_v1",
"swin_v2"])
args.add_argument("--epochs", type=int, default=50)
args.add_argument("--continual", action="store_true")
args.add_argument("--accelerate", action="store_true")
Expand All @@ -239,7 +284,7 @@ def parser():
params = parser()
dataset = Market1501(root="Market1501")

if params.backbone == "plr_osnet":
if params.backbone in ("plr_osnet", "seres18"):
# No need for cross-domain retrain
transform_train = transforms.Compose([
transforms.Resize((256, 128)),
Expand All @@ -252,7 +297,10 @@ def parser():
transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
market_dataset = MarketDataset(dataset.train, transform_train)
model = plr_osnet(num_classes=dataset.num_train_pids, loss='triplet').cuda()
if params.backbone == "plr_osnet":
model = plr_osnet(num_classes=dataset.num_train_pids, loss='triplet').cuda()
else:
model = seres18_ibn(num_classes=dataset.num_train_pids, loss="triplet").cuda()
model = nn.DataParallel(model)
model, loss_stats = train_plr_osnet(model, market_dataset, params.bs, params.epochs, dataset.num_train_pids,
params.accelerate)
Expand Down
5 changes: 2 additions & 3 deletions reid/segmentation.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
# experimental
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from copy import deepcopy
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import Dataset
from video_reid_train import TripletLoss, CenterLoss

import glob
from osnet import osnet_ibn_x1_0, OSNet
from reid.backbones.osnet import osnet_ibn_x1_0

# use deeplabv3_resnet50 instead of deeplabv3_resnet101 to reduce the model size
model = torch.hub.load('pytorch/vision:v0.8.0', 'deeplabv3_resnet50', pretrained=True)
Expand Down
3 changes: 1 addition & 2 deletions reid/video_reid_train.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from train_prepare import *
import torch.onnx
from tqdm import tqdm
from torch.autograd import Variable
Expand All @@ -9,7 +8,7 @@
import argparse
import madgrad

from video_model import resnet18, resnet34, resnet50, resnet101, resnet152, resnet200
from reid.backbones.video_model import resnet50
from train_utils import *

cudnn.deterministic = True
Expand Down

0 comments on commit 9f7edfa

Please sign in to comment.