# generate_image.py
"""
使用预训练的网络pickle生成图像
Copyright © [2023] [浙江大学计算机视觉人工智能国家重点实验室·Zhejiang HangZhou]. All Rights Reserved.
"""
import cv2
import pyspng
import glob
import os
import re
import random
from typing import List, Optional
import click
import dnnlib
import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F
import legacy
from datasets.mask_generator_512 import RandomMask
from networks.mat import Generator
def num_range(s: str) -> List[int]:
    """Parse a comma-separated list of integers and inclusive ranges.

    Each comma-separated token is either a single integer (``'7'``, ``'-1'``)
    or an inclusive range of non-negative integers written ``'a-b'``.
    Tokens may be freely mixed, e.g. ``'1,3-5'`` gives ``[1, 3, 4, 5]``.

    Args:
        s: The textual specification, e.g. ``'0-3'``, ``'1,2,5'`` or ``'1,3-5'``.

    Returns:
        The integers in the order they are written. A range whose start is
        greater than its end contributes no values.

    Raises:
        ValueError: If a token is neither an integer nor an ``a-b`` range.
    """
    values: List[int] = []
    for token in s.split(','):
        token = token.strip()
        m = re.fullmatch(r'(\d+)-(\d+)', token)
        if m:
            # Upper bound is inclusive, hence the +1.
            values.extend(range(int(m.group(1)), int(m.group(2)) + 1))
        else:
            # Plain integer; int() raises ValueError for malformed tokens.
            values.append(int(token))
    return values
def copy_params_and_buffers(src_module, dst_module, require_all=False):
    """Copy parameters and buffers from ``src_module`` into ``dst_module`` by name.

    Tensors are matched on the names produced by ``named_params_and_buffers``.
    Destination tensors without a same-named source tensor are left untouched
    unless ``require_all`` is set.

    Args:
        src_module: ``torch.nn.Module`` providing the values.
        dst_module: ``torch.nn.Module`` whose tensors are overwritten in place.
        require_all: When true, every destination tensor must have a
            same-named source tensor.

    Raises:
        AssertionError: If either argument is not a ``torch.nn.Module``, or if
            ``require_all`` is set and a destination tensor has no source.
    """
    assert isinstance(src_module, torch.nn.Module)
    assert isinstance(dst_module, torch.nn.Module)
    src_tensors = dict(named_params_and_buffers(src_module))
    dst_tensors = named_params_and_buffers(dst_module)

    # Validate before touching anything so a failure never leaves dst_module
    # half-updated. Raised explicitly (not via `assert`) so the check still
    # runs under `python -O`.
    if require_all:
        missing = [name for name, _ in dst_tensors if name not in src_tensors]
        if missing:
            raise AssertionError(
                f'source module is missing tensors required by destination: {missing}'
            )

    # In-place writes to leaf tensors that require grad are rejected while
    # autograd is recording, so the copy is done with grad mode disabled.
    with torch.no_grad():
        for name, tensor in dst_tensors:
            if name in src_tensors:
                tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)
def params_and_buffers(module):
    """Return a list with every parameter of ``module`` followed by every buffer.

    Raises:
        AssertionError: If ``module`` is not a ``torch.nn.Module``.
    """
    assert isinstance(module, torch.nn.Module)
    tensors = list(module.parameters())
    tensors.extend(module.buffers())
    return tensors
def named_params_and_buffers(module):
    """Return ``(name, tensor)`` pairs for all parameters, then all buffers, of ``module``.

    Raises:
        AssertionError: If ``module`` is not a ``torch.nn.Module``.
    """
    assert isinstance(module, torch.nn.Module)
    return [*module.named_parameters(), *module.named_buffers()]
@click.command()
@click.pass_context
@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
@click.option('--dpath', help='the path of the input image', required=True)
@click.option('--mpath', help='the path of the mask')
@click.option('--resolution', type=int, help='resolution of input image', default=512, show_default=True)
@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
@click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
@click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR')
def generate_images(
    ctx: click.Context,
    network_pkl: str,
    dpath: str,
    mpath: Optional[str],
    resolution: int,
    truncation_psi: float,
    noise_mode: str,
    outdir: str,
):
    """
    Generate images using pretrained network pickle.
    """
    # NOTE: the docstring above doubles as the Click `--help` text, so the
    # developer documentation is kept in comments.
    #
    # Pipeline: seed the RNGs, collect the input images (and optional masks),
    # load the generator weights, then run the generator on every image and
    # save each result as a PNG under `outdir`.
    #
    # Parameters:
    #   ctx            - Click context injected by @click.pass_context (unused).
    #   network_pkl    - path/URL of the pickle; its 'G_ema' entry is loaded.
    #   dpath          - directory searched for '*.png' and '*.jpg' inputs.
    #   mpath          - optional directory of masks; masks are paired with
    #                    images by sorted order, so the counts must match.
    #                    When omitted, a mask is drawn with RandomMask.
    #   resolution     - input resolution; the network is built with
    #                    min(resolution, 512), and any value other than 512
    #                    forces noise_mode to 'random'.
    #   truncation_psi - forwarded to the generator.
    #   noise_mode     - forwarded to the generator (see `resolution`).
    #   outdir         - output directory, created if missing.
    #
    # Raises:
    #   AssertionError - image and mask counts differ.
    #   OSError        - a mask file cannot be decoded by OpenCV.

    seed = 240  # fixed random seed so repeated runs are reproducible
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    print(f'Loading data from: {dpath}')
    img_list = sorted(glob.glob(dpath + '/*.png') + glob.glob(dpath + '/*.jpg'))

    # mask_list only exists when --mpath is given, so the count report and the
    # consistency check live inside the same branch.
    mask_list = None
    if mpath is not None:
        print(f'Loading mask from: {mpath}')
        mask_list = sorted(glob.glob(mpath + '/*.png') + glob.glob(mpath + '/*.jpg'))
        print("当前目录下有图片:",len(img_list),"张\n有掩膜:",len(mask_list),"张")
        if len(img_list) != len(mask_list):
            # Raised explicitly so the check also runs under `python -O`.
            raise AssertionError('图片和掩膜数量不匹配')
        print("结果:数量匹配")

    print(f'Loading networks from: {network_pkl}')
    device = torch.device('cuda')
    with dnnlib.util.open_url(network_pkl) as f:
        G_saved = legacy.load_network_pkl(f)['G_ema'].to(device).eval().requires_grad_(False)  # type: ignore
    # The network resolution is capped at 512 even for larger inputs.
    net_res = min(resolution, 512)
    G = Generator(z_dim=512, c_dim=0, w_dim=512, img_resolution=net_res, img_channels=3).to(device).eval().requires_grad_(False)
    copy_params_and_buffers(G_saved, G, require_all=True)

    os.makedirs(outdir, exist_ok=True)

    # no Labels.
    label = torch.zeros([1, G.c_dim], device=device)

    def read_image(image_path):
        """Load an image file as a CHW array with at most three channels."""
        with open(image_path, 'rb') as f:
            if pyspng is not None and image_path.endswith('.png'):
                image = pyspng.load(f.read())
            else:
                image = np.array(PIL.Image.open(f))
        if image.ndim == 2:
            image = image[:, :, np.newaxis]  # HW => HWC
            image = np.repeat(image, 3, axis=2)
        image = image.transpose(2, 0, 1)  # HWC => CHW
        image = image[:3]  # keep the first three channels only
        return image

    if resolution != 512:
        # Overrides the --noise-mode option for non-512 resolutions.
        noise_mode = 'random'

    with torch.no_grad():
        for i, ipath in enumerate(img_list):
            # Swap only the file extension; str.replace would also rewrite a
            # '.jpg' occurring elsewhere in the file name.
            stem, _ = os.path.splitext(os.path.basename(ipath))
            iname = stem + '.png'
            print(f'Prcessing: {iname}')
            image = read_image(ipath)
            # Rescale pixel values with x / 127.5 - 1 and add a batch axis.
            image = (torch.from_numpy(image).float().to(device) / 127.5 - 1).unsqueeze(0)

            if mask_list is not None:
                mask_path = mask_list[i]
                mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
                if mask is None:
                    # cv2.imread signals failure by returning None, not by raising.
                    raise OSError(f'Failed to read mask image: {mask_path}')
                mask = mask.astype(np.float32) / 255.0
                mask = torch.from_numpy(mask).float().to(device).unsqueeze(0).unsqueeze(0)
            else:
                mask = RandomMask(resolution)  # use hole_range to adjust the mask ratio
                mask = torch.from_numpy(mask).float().to(device).unsqueeze(0)

            z = torch.from_numpy(np.random.randn(1, G.z_dim)).to(device)
            output = G(image, mask, z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)
            # NCHW => NHWC, then rescale with x * 127.5 + 127.5 into uint8.
            output = (output.permute(0, 2, 3, 1) * 127.5 + 127.5).round().clamp(0, 255).to(torch.uint8)
            output = output[0].cpu().numpy()
            PIL.Image.fromarray(output, 'RGB').save(os.path.join(outdir, iname))
if __name__ == '__main__':
    # Script entry point: generate_images is a Click command, so it is invoked
    # without explicit arguments here.
    generate_images()