-import cv2
 import click
+import cv2
+import onegan
 import torch
-import torch.nn as nn
-import torchvision.transforms as transforms
-import numpy as np
+import torchvision.transforms as T
+from PIL import Image

-from training.models import VggFCN
-from tools import timeit, label_colormap
+from trainer.賣扣老師 import build_resnet101_FCN

 torch.backends.cudnn.benchmark = True


-class Demo():
-
-    transform = transforms.Compose([
-        transforms.ToTensor(),
-        transforms.Normalize(
-            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-    ])
+class Predictor:

     def __init__(self, input_size, weight=None):
-        self.input_size = input_size
-        self.num_class = 5
-        self.model = self.load_model(weight)
-        self.cmap = self.create_camp()
-
-    def create_camp(self):
-        return label_colormap(self.num_class + 1)[1:]
-
-    @timeit
-    def load_model(self, weight):
-        model = nn.DataParallel(
-            VggFCN(num_classes=5, input_size=self.input_size,
-                   pretrained=False)).cuda()
-        model.load_state_dict(torch.load(weight))
-        return model
-
-    @timeit
+        self.model = self.build_model(weight)
+        self.colorizer = onegan.extension.Colorizer(
+            colors=[
+                [249, 69, 93], [255, 229, 170], [144, 206, 181],
+                [81, 81, 119], [241, 247, 210]])
+        self.transform = T.Compose([
+            T.Resize(input_size),
+            T.ToTensor(),
+            T.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])
+        ])
+
+    def build_model(self, weight_path, joint_class=False):
+        model = build_resnet101_FCN(pretrained=False, nb_classes=37, stage_2=True, joint_class=joint_class)
+        weight = onegan.utils.export_checkpoint_weight(weight_path)
+        model.load_state_dict(weight)
+        model.eval()
+        return model.cuda()
+
+    @onegan.utils.timeit
     def process(self, raw):

-        def output_label(output):
-            label = np.zeros((*self.input_size, 3))
-            for lbl in range(self.num_class):
-                label[output.squeeze() == lbl] = self.cmap[lbl]
-            return label
-
-        img = cv2.resize(raw, self.input_size)
-        batched_img = self.transform(img).unsqueeze(0).cuda()
-
-        output = self.model.module.predict(batched_img)
-        label = output_label(output)
+        def _batched_process(batched_img):
+            score, _ = self.model(onegan.utils.to_var(batched_img))
+            _, output = torch.max(score, 1)

-        label = cv2.resize(label, (raw.shape[1], raw.shape[0]))
+            image = (batched_img / 2 + .5)
+            layout = self.colorizer.apply(output.data.cpu())
+            return image * .6 + layout * .4

-        return raw.astype(np.float32) / 255 * 0.5 + label
+        img = Image.fromarray(raw)
+        batched_img = self.transform(img).unsqueeze(0)
+        canvas = _batched_process(batched_img)
+        result = canvas.squeeze().permute(1, 2, 0).numpy()
+        return cv2.resize(result, (raw.shape[1], raw.shape[0]))


 @click.command()
 @click.option('--device', default=0)
-@click.option('--video', default='')
-@click.option('--weight', default='output/weight/vgg_bn_new/24.pth')
-@click.option('--input_size', default=(404, 404), type=(int, int))
+@click.option('--video', type=click.Path(exists=True))
+@click.option('--weight', type=click.Path(exists=True))
+@click.option('--input_size', default=(320, 320), type=(int, int))
 def main(device, video, weight, input_size):

-    demo = Demo(input_size, weight=weight)
+    demo = Predictor(input_size, weight=weight)

     reader = video if video else device
     cap = cv2.VideoCapture(reader)
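For context, a minimal sketch of how the new Predictor could be driven from the capture handle opened in main(). Only Predictor, process(), and cv2.VideoCapture come from the diff above; the read/display loop, window name, and weight path below are assumptions for illustration, not part of this commit.

# Hypothetical usage inside this script -- the loop, window name, and weight
# path are assumptions; only Predictor and cv2.VideoCapture appear in the diff.
demo = Predictor((320, 320), weight='checkpoints/model.pth')  # assumed checkpoint path
cap = cv2.VideoCapture(0)

while True:
    ok, frame = cap.read()                          # OpenCV yields BGR uint8 frames
    if not ok:
        break
    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)    # process() builds a PIL image, so feed RGB
    blended = demo.process(rgb)                     # float array blending the frame with the predicted layout
    cv2.imshow('layout', cv2.cvtColor(blended, cv2.COLOR_RGB2BGR))
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

cap.release()
cv2.destroyAllWindows()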