-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcheckdata.py
59 lines (51 loc) · 1.8 KB
/
checkdata.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from __future__ import print_function
import os
import random
import shutil
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from PIL import Image
from utils import load_model, AverageMeter, accuracy
from PIL import Image
# Export a handful of clean training images alongside their perturbed
# counterparts as PNG files so the two sets can be compared by eye.
#
# NOTE(review): Image.fromarray below assumes each sample is an array PIL can
# encode directly as PNG (e.g. uint8 with shape HxWx3) — confirm the dtype and
# shape of both .npy files; a float array would give mode 'F', which PNG
# cannot store.
images_origin = np.load('./datasets/cifar_train3_image.npy')
# Perturbed version of the same split (the file name suggests UAT examples
# generated against a WideResNet — TODO confirm with the generating script).
images = np.load('./datasets/train3_CUAT_wideresnet_image.npy')
# Exploratory snippets for inspecting / merging soft labels, kept for reference:
# labels_origin = np.load('./datasets/cifar_train2_label.npy')
# labels = np.load('./datasets/train2_PGD-8_densenet_label.npy')
# images_origin = images_origin / images_origin.sum(axis=1, keepdims=True)
# merge = (images_origin + images)/2.
# np.save('./cifar_distill_label.npy', merge)
# print(np.argmax(images_origin[0:10], axis=1))
# print(np.argmax(images[0:10], axis=1))
# print(np.argmax(merge[0:10], axis=1))
# print(images_origin[1])
# print(images[1])
# Class names indexed by label id (only used by the commented-out label
# prints inside the loop below).
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
# Indices of the samples to export.
showlist = [0, 1, 2, 3, 4]
# Image.save() does not create missing parent directories, so make sure the
# output folder exists before writing into it.
output_dir = './show'
os.makedirs(output_dir, exist_ok=True)
for i in showlist:
    image_origin = Image.fromarray(images_origin[i])
    image = Image.fromarray(images[i])
    # Debug checks, kept for reference:
    # if (image == image_origin):
    #     print("True")
    # else:
    #     print("False")
    # print(images[i][0][0])
    # #print(images_origin[i][0][0])
    # print("!!!")
    #print(classes[np.argmax(labels_origin[i])])
    #print(classes[np.argmax(labels[i])])
    # Same file names as before: ./show/train3_<i>.png and ./show/UAT_<i>.png
    image_origin.save(os.path.join(output_dir, 'train3_' + str(i) + '.png'))
    image.save(os.path.join(output_dir, 'UAT_' + str(i) + '.png'))
    # print(labels[i])
print(images.shape)
# print('origin:', images_origin[0][0][0])
# print('PGD:', images[0][0][0])
# np.save('data.npy', images_merge)
# np.save('label.npy', labels_merge)
# Best result so far: 1PGD-d+2PGD-d+3PGD-d+light+w10