-
Notifications
You must be signed in to change notification settings - Fork 103
/
cloth-mask.py
124 lines (100 loc) · 3.91 KB
/
cloth-mask.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import os
from PIL import Image
import numpy as np
from collections import OrderedDict
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from networks.u2net import U2NET
# Run on the GPU when one is available; otherwise fall back to the CPU so the
# script still works on CUDA-less hosts instead of failing in `net.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Directory whose entries are read as input cloth images.
image_dir = '/content/inputs/test/cloth'
# Directory the generated masks are written to.
result_dir = '/content/inputs/test/cloth-mask'
# Path of the U2NET checkpoint passed to `load_checkpoint_mgpu`.
checkpoint_path = 'cloth_segm_u2net_latest.pth'
def load_checkpoint_mgpu(model, checkpoint_path):
    """Load a checkpoint into ``model``, dropping any ``module.`` key prefix.

    The state dict is read onto the CPU. Keys that start with ``module.``
    (the prefix added by wrappers such as ``torch.nn.DataParallel``) have that
    prefix removed; keys without it are used unchanged, so checkpoints saved
    from an unwrapped model load correctly as well.

    Args:
        model: Module exposing ``load_state_dict``; it is updated in place.
        checkpoint_path: Path of the file to read with ``torch.load``.

    Returns:
        ``model`` after its weights were loaded, or ``None`` when
        ``checkpoint_path`` does not exist (a message is printed in that
        case). Callers must check for ``None`` before using the result.
    """
    if not os.path.exists(checkpoint_path):
        print("----No checkpoints at given path----")
        return None

    model_state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"))

    # Strip the prefix only where it is actually present; slicing every key
    # blindly would corrupt the names of an unprefixed checkpoint.
    prefix = "module."
    new_state_dict = OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in model_state_dict.items()
    )

    model.load_state_dict(new_state_dict)
    print("----checkpoints loaded from path: {}----".format(checkpoint_path))
    return model
class Normalize_image(object):
    """Normalize an image tensor using one scalar mean and std for all channels.

    The channel count is taken from the first dimension of the tensor;
    tensors with 1, 3 or 18 channels are supported.

    Args:
        mean (float): Desired mean to subtract from tensors.
        std (float): Desired std to divide tensors by.

    Raises:
        TypeError: If ``mean`` or ``std`` is not a ``float``.
    """

    def __init__(self, mean, std):
        # Validate with real exceptions: `assert` is stripped under `-O`, and
        # an unchecked `std` would only fail later with an AttributeError.
        if not isinstance(mean, float):
            raise TypeError(
                "mean must be a float, got {}".format(type(mean).__name__)
            )
        if not isinstance(std, float):
            raise TypeError(
                "std must be a float, got {}".format(type(std).__name__)
            )
        self.mean = mean
        self.std = std

        # One pre-built normalizer per supported channel count.
        self.normalize_1 = transforms.Normalize(self.mean, self.std)
        self.normalize_3 = transforms.Normalize([self.mean] * 3, [self.std] * 3)
        self.normalize_18 = transforms.Normalize([self.mean] * 18, [self.std] * 18)

    def __call__(self, image_tensor):
        """Return ``image_tensor`` normalized by the matching normalizer.

        Args:
            image_tensor: Tensor whose first dimension is the channel count.

        Returns:
            The normalized tensor.

        Raises:
            ValueError: If the channel count is not 1, 3 or 18.
        """
        channels = image_tensor.shape[0]
        if channels == 1:
            return self.normalize_1(image_tensor)
        if channels == 3:
            return self.normalize_3(image_tensor)
        if channels == 18:
            return self.normalize_18(image_tensor)
        # The previous `assert "<message>"` was always true and therefore
        # returned None silently; fail loudly instead.
        raise ValueError(
            "Please set proper channels! Normalization implemented only for "
            "1, 3 and 18 (got {})".format(channels)
        )
def get_palette(num_cls):
    """Build the flat RGB palette used to render the segmentation mask.

    Class 0 is mapped to black (0, 0, 0) and every other class to white
    (255, 255, 255), which yields a binary mask.

    Args:
        num_cls: Number of classes

    Returns:
        A list of ``num_cls * 3`` ints laid out as ``[R0, G0, B0, R1, ...]``.
    """
    flat_colors = []
    for label in range(num_cls):
        # Only the background label (0) stays black.
        shade = 255 if label else 0
        flat_colors.extend([shade] * 3)
    return flat_colors
# Preprocessing pipeline: convert the PIL image to a tensor, then normalize
# every channel with mean 0.5 and std 0.5.
transforms_list = [transforms.ToTensor(), Normalize_image(0.5, 0.5)]
transform_rgb = transforms.Compose(transforms_list)

# Build the network and load its weights. The loader returns None when the
# checkpoint file is missing, so stop here with a clear error rather than
# failing on `None.to(...)` below.
net = U2NET(in_ch=3, out_ch=4)
net = load_checkpoint_mgpu(net, checkpoint_path)
if net is None:
    raise FileNotFoundError(
        "No checkpoint found at: {}".format(checkpoint_path)
    )
net = net.to(device)
net = net.eval()

palette = get_palette(4)

# Make sure the output directory exists before the first save.
os.makedirs(result_dir, exist_ok=True)

images_list = sorted(os.listdir(image_dir))
for image_name in images_list:
    # Open inside a context manager so the file handle is released promptly.
    with Image.open(os.path.join(image_dir, image_name)) as source_img:
        img = source_img.convert('RGB')
    img_size = img.size  # remembered to restore the original resolution
    img = img.resize((768, 768), Image.BICUBIC)
    image_tensor = transform_rgb(img)
    image_tensor = torch.unsqueeze(image_tensor, 0)  # add a batch dimension

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        output_tensor = net(image_tensor.to(device))
    # Per-pixel class index = argmax over dim 1 of the first network output.
    output_tensor = F.log_softmax(output_tensor[0], dim=1)
    output_tensor = torch.max(output_tensor, dim=1, keepdim=True)[1]
    # Drop the two leading size-1 dimensions (presumably batch and channel).
    output_tensor = torch.squeeze(output_tensor, dim=0)
    output_tensor = torch.squeeze(output_tensor, dim=0)
    output_arr = output_tensor.cpu().numpy()

    output_img = Image.fromarray(output_arr.astype('uint8'), mode='L')
    output_img = output_img.resize(img_size, Image.BICUBIC)
    # Map class indices through the palette, then flatten to grayscale.
    output_img.putpalette(palette)
    output_img = output_img.convert('L')

    # Use splitext so extensions of any length (e.g. ".jpeg") are handled.
    output_name = os.path.splitext(image_name)[0] + '.jpg'
    output_img.save(os.path.join(result_dir, output_name))