-
Notifications
You must be signed in to change notification settings - Fork 1
/
predictByCEAL.py~
87 lines (74 loc) · 3.2 KB
/
predictByCEAL.py~
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import os
import numpy as np
import tensorflow as tf
from alexnet import AlexNet
from dataprocess import ImageDataGenerator
from datetime import datetime
from tensorflow.contrib.data import Iterator
from config import *
# Restrict CUDA to device '3', enumerating devices in PCI bus order.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = '3'

# Input pipeline over the samples listed in `rest_file`, pinned to the CPU.
# NOTE(review): the source was recovered without indentation, so the exact
# extent of this `with` block is reconstructed -- confirm against the
# original file.
with tf.device('/cpu:0'):
    val_data = ImageDataGenerator(
        rest_file,
        mode='inference',
        batch_size=test_batch_size,
        num_classes=num_classes,
        shuffle=False,
    )
    _val_dataset = val_data.data
    iterator = Iterator.from_structure(_val_dataset.output_types,
                                       _val_dataset.output_shapes)
    next_batch = iterator.get_next()

# Op that (re)binds the iterator to the dataset built above.
validation_init_op = iterator.make_initializer(_val_dataset)

# Graph inputs: a fixed-size batch of 227x227x3 images, their label rows
# (one column per class) and the dropout keep probability.
x = tf.placeholder(tf.float32, shape=[test_batch_size, 227, 227, 3])
y = tf.placeholder(tf.float32, shape=[test_batch_size, num_classes])
keep_prob = tf.placeholder(tf.float32)

# Network and its `fc8_softmax` output, used below as per-class scores.
model = AlexNet(x, keep_prob, num_classes, train_layers)
score = model.fc8_softmax

with tf.name_scope("accuracy"):
    # Highest-scoring class per sample; fetched later as the pseudo-label.
    index = tf.argmax(score, 1)
    # Fraction of samples whose highest-scoring class matches argmax of `y`.
    predicted_class = tf.argmax(score, 1)
    true_class = tf.argmax(y, 1)
    correct_pred = tf.equal(predicted_class, true_class)
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

saver = tf.train.Saver()

# Number of complete batches; a trailing partial batch is not counted.
val_batches_per_epoch = int(val_data.data_size // test_batch_size)
# Confidence cut-offs applied to the highest class score of each sample.
# NOTE(review): if `score` is a normalised softmax distribution, its maximum
# can never be below 1/num_classes, so the low-confidence branch can only
# fire when num_classes > 10 -- confirm that 0.1 is the intended threshold.
LOW_CONF_THRESHOLD = 0.1
HIGH_CONF_THRESHOLD = 0.9


def _as_text(name):
    """Return `name` as a native string so it can be joined with str pieces.

    Under Python 2 `bytes` is `str`, so the value is returned unchanged.
    Presumably the fetched filenames are byte strings under Python 3; those
    are decoded as UTF-8.
    """
    if isinstance(name, bytes) and not isinstance(name, str):
        return name.decode('utf-8')
    return name


# Low-confidence candidates as (confidence, output line) records. A list is
# used instead of a dict keyed by confidence so that two samples with the
# same confidence are both kept.
low_conf_outlines = []
# Output lines "<filename> <predicted class>" for confidently predicted samples.
high_conf_outlines = []

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # Restore the latest checkpoint recorded under `checkpoint_path`.
    ckpt = tf.train.get_checkpoint_state(checkpoint_path)
    if ckpt and ckpt.model_checkpoint_path:
        ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
        saver.restore(sess, os.path.join(checkpoint_path, ckpt_name))
        print("load model successfully!!")
    else:
        # Keep the original best-effort behaviour (continue with the freshly
        # initialised variables) but make it visible instead of silent.
        print("WARNING: no checkpoint found in {}; predictions come from "
              "freshly initialised weights".format(checkpoint_path))

    print("{} Start CEAL...".format(datetime.now()))
    sess.run(validation_init_op)

    test_acc = 0.
    test_count = 0
    for _ in range(val_batches_per_epoch):
        filename_batch, img_batch, label_batch = sess.run(next_batch)
        pred_index, pred_score, acc = sess.run(
            [index, score, accuracy],
            feed_dict={x: img_batch, y: label_batch, keep_prob: 1.})

        for name, class_scores, pred_label, label_row in zip(
                filename_batch, pred_score, pred_index, label_batch):
            name = _as_text(name)
            max_pred_conf = np.max(class_scores)
            if max_pred_conf < LOW_CONF_THRESHOLD:
                # Uncertain sample: emitted with the label taken from the
                # dataset (argmax of its label row), not the prediction.
                low_conf_outlines.append(
                    (max_pred_conf,
                     name + ' ' + str(np.argmax(label_row)) + '\n'))
            elif max_pred_conf > HIGH_CONF_THRESHOLD:
                # Confident sample: emitted with the predicted class.
                high_conf_outlines.append(
                    name + ' ' + str(pred_label) + '\n')

        test_acc += acc
        test_count += 1

# Average the per-batch accuracies; skip the division when the dataset holds
# less than one full batch, which would otherwise raise ZeroDivisionError.
if test_count:
    test_acc /= test_count
    print("{} Pseudo-labeling Accuracy = {:.4f}".format(datetime.now(), test_acc))
else:
    print("{} No full batch available; accuracy not computed".format(datetime.now()))

# Least confident first. The original called sorted() and discarded the
# result, so the selection below was taken in arbitrary dict order.
low_conf_outlines.sort(key=lambda record: record[0])

# Mode 'w' replaces any previous `pred_file` (equivalent to the former
# remove-then-append), and `with` closes the handle even if a write fails.
with open(pred_file, 'w') as f:
    f.writelines(high_conf_outlines)
    print('Pseudo-labeled Image Number' + str(len(high_conf_outlines)))

    # Append at most `selectALNUM` of the least confident samples.
    for count, (_conf, line) in enumerate(low_conf_outlines):
        if count == selectALNUM:
            break
        f.write(line)

print('Finish CEAL!')