-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
345 lines (274 loc) · 10.8 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
from matplotlib import pyplot as plt # 展示图片
import numpy as np # 数值处理
import cv2 # opencv库
from sklearn.linear_model import LinearRegression, Ridge, Lasso # 回归分析
def read_image(img_path):
"""
读取图片,图片是以 np.array 类型存储
:param img_path: 图片的路径以及名称
:return: img np.array 类型存储
"""
# 读取图片
img = cv2.imread(img_path)
# 如果图片是三通道,采用 matplotlib 展示图像时需要先转换通道
if len(img.shape) == 3:
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
return img
def plot_image(image, image_title, is_axis=False):
"""
展示图像
:param image: 展示的图像,一般是 np.array 类型
:param image_title: 展示图像的名称
:param is_axis: 是否需要关闭坐标轴,默认展示坐标轴
:return:
"""
# 展示图片
plt.imshow(image)
# 关闭坐标轴,默认关闭
if not is_axis:
plt.axis('off')
# 展示受损图片的名称
plt.title(image_title)
# 展示图片
plt.show()
def save_image(filename, image):
"""
将np.ndarray 图像矩阵保存为一张 png 或 jpg 等格式的图片
:param filename: 图片保存路径及图片名称和格式
:param image: 图像矩阵,一般为np.array
:return:
"""
# np.copy() 函数创建一个副本。
# 对副本数据进行修改,不会影响到原始数据,它们物理内存不在同一位置。
img = np.copy(image)
# 从给定数组的形状中删除一维的条目
img = img.squeeze()
# 将图片数据存储类型改为 np.uint8
if img.dtype == np.double:
# 若img数据存储类型是 np.double ,则转化为 np.uint8 形式
img = img * np.iinfo(np.uint8).max
# 转换图片数组数据类型
img = img.astype(np.uint8)
# 将 RGB 方式转换为 BGR 方式
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
# 生成图片
cv2.imwrite(filename, img)
def normalization(image):
"""
将数据线性归一化
:param image: 图片矩阵,一般是np.array 类型
:return: 将归一化后的数据,在(0,1)之间
"""
# 获取图片数据类型对象的最大值和最小值
info = np.iinfo(image.dtype)
# 图像数组数据放缩在 0-1 之间
return image.astype(np.double) / info.max
def noise_mask_image(img, noise_ratio=[0.8,0.4,0.6]):
"""
根据题目要求生成受损图片
:param img: cv2 读取图片,而且通道数顺序为 RGB
:param noise_ratio: 噪声比率,类型是 List,,内容:[r 上的噪声比率,g 上的噪声比率,b 上的噪声比率]
默认值分别是 [0.8,0.4,0.6]
:return: noise_img 受损图片, 图像矩阵值 0-1 之间,数据类型为 np.array,
数据类型对象 (dtype): np.double, 图像形状:(height,width,channel),通道(channel) 顺序为RGB
"""
# 受损图片初始化
noise_img = None
# -------------实现受损图像答题区域-----------------
import random
noise_img = np.copy(img)
for i in range(3):
for j in range(img.shape[0]):
mask = list(range(img.shape[1]))
mask = random.sample(mask, int(img.shape[1]*noise_ratio[i]))
for k in range(img.shape[1]):
if k in mask:
noise_img[j,k,i] = 0
# -----------------------------------------------
return noise_img
def get_noise_mask(noise_img):
"""
获取噪声图像,一般为 np.array
:param noise_img: 带有噪声的图片
:return: 噪声图像矩阵
"""
# 将图片数据矩阵只包含 0和1,如果不能等于 0 则就是 1。
return np.array(noise_img != 0, dtype='double')
def compute_error(res_img, img):
"""
计算恢复图像 res_img 与原始图像 img 的 2-范数
:param res_img:恢复图像
:param img:原始图像
:return: 恢复图像 res_img 与原始图像 img 的2-范数
"""
# 初始化
error = 0.0
# 将图像矩阵转换成为np.narray
res_img = np.array(res_img)
img = np.array(img)
# 如果2个图像的形状不一致,则打印出错误结果,返回值为 None
if res_img.shape != img.shape:
print("shape error res_img.shape and img.shape %s != %s" % (res_img.shape, img.shape))
return None
# 计算图像矩阵之间的评估误差
error = np.sqrt(np.sum(np.power(res_img - img, 2)))
return round(error,3)
# 计算平面二维向量的 2-范数值
img0 = [1, 0]
img1 = [0, 1]
print("平面向量的评估误差:", compute_error(img0, img1))
from skimage.measure import compare_ssim as ssim
from scipy import spatial
def calc_ssim(img, img_noise):
"""
计算图片的结构相似度
:param img: 原始图片, 数据类型为 ndarray, shape 为[长, 宽, 3]
:param img_noise: 噪声图片或恢复后的图片,
数据类型为 ndarray, shape 为[长, 宽, 3]
:return:
"""
return ssim(img, img_noise,
multichannel=True,
data_range=img_noise.max() - img_noise.min())
def calc_csim(img, img_noise):
"""
计算图片的 cos 相似度
:param img: 原始图片, 数据类型为 ndarray, shape 为[长, 宽, 3]
:param img_noise: 噪声图片或恢复后的图片,
数据类型为 ndarray, shape 为[长, 宽, 3]
:return:
"""
img = img.reshape(-1)
img_noise = img_noise.reshape(-1)
return 1 - spatial.distance.cosine(img, img_noise)
from PIL import Image
import numpy as np
def read_img(path):
img = Image.open(path)
img = img.resize((150,150))
img = np.asarray(img, dtype="uint8")
# 获取图片数据类型对象的最大值和最小值
info = np.iinfo(img.dtype)
# 图像数组数据放缩在 0-1 之间
return img.astype(np.double) / info.max
img = read_img('A.png')
noise = np.ones_like(img) * 0.2 * (img.max() - img.min())
noise[np.random.random(size=noise.shape) > 0.5] *= -1
img_noise = img + abs(noise)
print('相同图片的 SSIM 相似度: ', calc_ssim(img, img))
print('相同图片的 Cosine 相似度: ', calc_csim(img, img))
print('与噪声图片的 SSIM 相似度: ', calc_ssim(img, img_noise))
print('与噪声图片的 Cosine 相似度: ', calc_csim(img, img_noise))
def restore_image(noise_img, size=4):
"""
使用 你最擅长的算法模型 进行图像恢复。
:param noise_img: 一个受损的图像
:param size: 输入区域半径,长宽是以 size*size 方形区域获取区域, 默认是 4
:return: res_img 恢复后的图片,图像矩阵值 0-1 之间,数据类型为 np.array,
数据类型对象 (dtype): np.double, 图像形状:(height,width,channel), 通道(channel) 顺序为RGB
"""
# 恢复图片初始化,首先 copy 受损图片,然后预测噪声点的坐标后作为返回值。
res_img = np.copy(noise_img)
# 获取噪声图像
noise_mask = get_noise_mask(noise_img)
# -------------实现图像恢复代码答题区域----------------------------
for i in range(noise_mask.shape[0]):
for j in range(noise_mask.shape[1]):
for k in range(noise_mask.shape[2]):
if noise_mask[i,j,k] == 0:
sc = 1
listx = get_window_small(res_img,noise_mask,i,j,k)
if len(listx) != 0:
res_img[i,j,k] = listx[len(listx)//2]
else:
while(len(listx) == 0):
listx = get_window(res_img,noise_mask,sc,i,j,k)
sc = sc+1
if sc > 4:
res_img[i,j,k] = np.mean(listx)
else:
res_img[i,j,k] = listx[len(listx)//2]
# ---------------------------------------------------------------
return res_img
def get_window(res_img,noise_mask,sc,i,j,k):
listx = []
if i-sc >= 0:
starti = i-sc
else:
starti = 0
if j+1 <= res_img.shape[1]-1 and noise_mask[0,j+1,k] !=0:
listx.append(res_img[0,j+1,k])
if j-1 >=0 and noise_mask[0,j-1,k] !=0:
listx.append(res_img[0,j-1,k])
if i+sc <= res_img.shape[0]-1:
endi = i+sc
else:
endi = res_img.shape[0]-1
if j+1 <= res_img.shape[1]-1 and noise_mask[endi,j+1,k] !=0:
listx.append(res_img[endi,j+1,k])
if j-1 >=0 and noise_mask[endi,j-1,k] !=0:
listx.append(res_img[endi,j-1,k])
if j+sc <= res_img.shape[1]-1:
endj = j+sc
else:
endj = res_img.shape[1]-1
if i+1 <= res_img.shape[0]-1 and noise_mask[i+1,endj,k] !=0:
listx.append(res_img[i+1,endj,k])
if i-1 >=0 and noise_mask[i-1,endj,k] !=0:
listx.append(res_img[i-1,endj,k])
if j-sc >= 0:
startj = j-sc
else:
startj = 0
if i+1 <= res_img.shape[0]-1 and noise_mask[i+1,0,k] !=0:
listx.append(res_img[i+1,0,k])
if i-1 >=0 and noise_mask[i-1,0,k] !=0:
listx.append(res_img[i-1,0,k])
for m in range(starti,endi+1):
for n in range(startj,endj+1):
if noise_mask[m,n,k] != 0:
listx.append(res_img[m,n,k])
listx.sort()
return listx
def get_window_small(res_img,noise_mask,i,j,k):
listx = []
sc = 1
if i-sc >= 0 and noise_mask[i-1,j,k]!=0:
listx.append(res_img[i-1,j,k])
if i+sc <= res_img.shape[0]-1 and noise_mask[i+1,j,k]!=0:
listx.append(res_img[i+1,j,k])
if j+sc <= res_img.shape[1]-1 and noise_mask[i,j+1,k]!=0:
listx.append(res_img[i,j+1,k])
if j-sc >= 0 and noise_mask[i,j-1,k]!=0:
listx.append(res_img[i,j-1,k])
listx.sort()
return listx
# 原始图片
# 加载图片的路径和名称
img_path = 'A.png'
# 读取原始图片
img = read_image(img_path)
# 展示原始图片
plot_image(image=img, image_title="original image")
# 生成受损图片
# 图像数据归一化
nor_img = normalization(img)
noise_ratio = [0.4, 0.6, 0.8]
# 生成受损图片
noise_img = noise_mask_image(nor_img, noise_ratio)
if noise_img is not None:
# 展示受损图片
plot_image(image=noise_img, image_title="the noise_ratio = %s of original image"%noise_ratio)
# 恢复图片
res_img = restore_image(noise_img)
# 计算恢复图片与原始图片的误差
print("恢复图片与原始图片的评估误差: ", compute_error(res_img, nor_img))
print("恢复图片与原始图片的 SSIM 相似度: ", calc_ssim(res_img, nor_img))
print("恢复图片与原始图片的 Cosine 相似度: ", calc_csim(res_img, nor_img))
# 展示恢复图片
plot_image(image=res_img, image_title="restore image")
# 保存恢复图片
save_image('res_' + img_path, res_img)
else:
# 未生成受损图片
print("返回值是 None, 请生成受损图片并返回!")