
Commit 65998e0

update main and demo script for new structure
1 parent bfad059 commit 65998e0

File tree: demo.py, main.py (2 files changed: +56, -55 lines)


demo.py

Lines changed: 40 additions & 47 deletions
@@ -1,69 +1,62 @@
-import cv2
 import click
+import cv2
+import onegan
 import torch
-import torch.nn as nn
-import torchvision.transforms as transforms
-import numpy as np
+import torchvision.transforms as T
+from PIL import Image

-from training.models import VggFCN
-from tools import timeit, label_colormap
+from trainer.賣扣老師 import build_resnet101_FCN

 torch.backends.cudnn.benchmark = True


-class Demo():
-
-    transform = transforms.Compose([
-        transforms.ToTensor(),
-        transforms.Normalize(
-            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-    ])
+class Predictor:

     def __init__(self, input_size, weight=None):
-        self.input_size = input_size
-        self.num_class = 5
-        self.model = self.load_model(weight)
-        self.cmap = self.create_camp()
-
-    def create_camp(self):
-        return label_colormap(self.num_class + 1)[1:]
-
-    @timeit
-    def load_model(self, weight):
-        model = nn.DataParallel(
-            VggFCN(num_classes=5, input_size=self.input_size,
-                   pretrained=False)).cuda()
-        model.load_state_dict(torch.load(weight))
-        return model
-
-    @timeit
+        self.model = self.build_model(weight)
+        self.colorizer = onegan.extension.Colorizer(
+            colors=[
+                [249, 69, 93], [255, 229, 170], [144, 206, 181],
+                [81, 81, 119], [241, 247, 210]])
+        self.transform = T.Compose([
+            T.Resize(input_size),
+            T.ToTensor(),
+            T.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])
+        ])
+
+    def build_model(self, weight_path, joint_class=False):
+        model = build_resnet101_FCN(pretrained=False, nb_classes=37, stage_2=True, joint_class=joint_class)
+        weight = onegan.utils.export_checkpoint_weight(weight_path)
+        model.load_state_dict(weight)
+        model.eval()
+        return model.cuda()
+
+    @onegan.utils.timeit
     def process(self, raw):

-        def output_label(output):
-            label = np.zeros((*self.input_size, 3))
-            for lbl in range(self.num_class):
-                label[output.squeeze() == lbl] = self.cmap[lbl]
-            return label
-
-        img = cv2.resize(raw, self.input_size)
-        batched_img = self.transform(img).unsqueeze(0).cuda()
-
-        output = self.model.module.predict(batched_img)
-        label = output_label(output)
+        def _batched_process(batched_img):
+            score, _ = self.model(onegan.utils.to_var(batched_img))
+            _, output = torch.max(score, 1)

-        label = cv2.resize(label, (raw.shape[1], raw.shape[0]))
+            image = (batched_img / 2 + .5)
+            layout = self.colorizer.apply(output.data.cpu())
+            return image * .6 + layout * .4

-        return raw.astype(np.float32) / 255 * 0.5 + label
+        img = Image.fromarray(raw)
+        batched_img = self.transform(img).unsqueeze(0)
+        canvas = _batched_process(batched_img)
+        result = canvas.squeeze().permute(1, 2, 0).numpy()
+        return cv2.resize(result, (raw.shape[1], raw.shape[0]))


 @click.command()
 @click.option('--device', default=0)
-@click.option('--video', default='')
-@click.option('--weight', default='output/weight/vgg_bn_new/24.pth')
-@click.option('--input_size', default=(404, 404), type=(int, int))
+@click.option('--video', type=click.Path(exists=True))
+@click.option('--weight', type=click.Path(exists=True))
+@click.option('--input_size', default=(320, 320), type=(int, int))
 def main(device, video, weight, input_size):

-    demo = Demo(input_size, weight=weight)
+    demo = Predictor(input_size, weight=weight)

     reader = video if video else device
     cap = cv2.VideoCapture(reader)
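
Note: the diff cuts off right after the VideoCapture setup, so the rest of main is not shown in this commit view. As a rough sketch only (not code from this commit, with run, frame, and blended being hypothetical names), a capture loop driving the new Predictor could look like this:

    import cv2

    def run(demo, cap):
        # Feed webcam/video frames through Predictor.process and display the
        # blended layout overlay; process expects an H x W x 3 uint8 frame and
        # returns a float image in roughly [0, 1] with the same height and width.
        while cap.isOpened():
            ok, frame = cap.read()
            if not ok:
                break
            blended = demo.process(frame)
            cv2.imshow('layout', blended)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
        cap.release()
        cv2.destroyAllWindows()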

main.py

Lines changed: 16 additions & 8 deletions
@@ -23,6 +23,21 @@ def create_dataset(args):
     return Dataset('train', args=args).to_loader(args=args), Dataset('val', args=args).to_loader(args=args)


+def create_model(args):
+    return {
+        'vgg': lambda: VggFCN(
+            num_class=args.num_class, input_size=(args.image_size, args.image_size), pretrained=True, base='vgg16_bn'),
+        'lavgg': lambda: LaFCN(
+            num_classes=args.num_class, input_size=None, pretrained=True, base='vgg16_bn'),
+        'resnet': lambda: ResFCN(
+            num_classes=args.num_class, num_room_types=11, pretrained=True, base='resnet101'),
+        'drn': lambda: DilatedResFCN(
+            num_classes=args.num_class, num_room_types=11, pretrained=True, base='resnet101'),
+        'mike': lambda: build_resnet101_FCN(
+            pretrained=True, nb_classes=37, stage_2=True, joint_class=not args.disjoint_class)
+    }[args.arch]()
+
+
 def create_optim(args, model, optim='sgd'):
     return {
         'adam': lambda: torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.5, 0.999)),
@@ -51,14 +66,7 @@ def main(args):
     log.info(''.join([f'\n-- {k}: {v}' for k, v in vars(args).items()]))

     train_loader, val_loader = create_dataset(args)
-
-    model = {
-        'vgg': lambda args: VggFCN(num_class=args.num_class, input_size=(args.image_size, args.image_size), pretrained=True, base='vgg16_bn'),
-        'lavgg': lambda args: LaFCN(num_classes=args.num_class, input_size=None, pretrained=True, base='vgg16_bn'),
-        'resnet': lambda args: ResFCN(num_classes=args.num_class, num_room_types=11, pretrained=True, base='resnet101'),
-        'drn': lambda args: DilatedResFCN(num_classes=args.num_class, num_room_types=11, pretrained=True, base='resnet101'),
-        'mike': lambda args: build_resnet101_FCN(pretrained=True, nb_classes=37, stage_2=True, joint_class=not args.disjoint_class)
-    }[args.arch](args)
+    model = create_model(args)

     if args.phase == 'train':
         training_estimator = core.training_estimator(
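
Note: the new create_model factory selects an architecture by args.arch and only constructs the chosen one, since each entry is a lambda. A minimal usage sketch, assuming hypothetical values for num_class and image_size (the real defaults live in the project's own argument parser):

    from argparse import Namespace

    # Illustrative argument values; only the fields read by create_model are set here.
    args = Namespace(arch='drn', num_class=5, image_size=320, disjoint_class=False)

    # Builds DilatedResFCN(num_classes=5, num_room_types=11, pretrained=True, base='resnet101')
    model = create_model(args)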
