-
Notifications
You must be signed in to change notification settings - Fork 1
/
predict.py
94 lines (84 loc) · 3.07 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
__author__ = 'dudevil'
import sys
import os.path
import argparse
import numpy as np
import pandas as pd
import theano
import theano.tensor as T
from lasagne.utils import floatX
from utils import load_network, get_predictions
from load_dataset import DataLoader
# Number of images assumed per prediction batch; the main script uses it to
# estimate progress while iterating over the test chunks.
BATCH_SIZE = 64
# Value handed to DataLoader as `image_size`
# (presumably the side length of the square input images in pixels -- TODO confirm).
IMAGE_SIZE = 128
def save_submission(predictions, filenames, n=1):
    """Write predictions to ``data/submissions/submission_<n>.csv``.

    Rows are indexed by each image's base name with directory and extension
    stripped. Only the first ``len(filenames)`` predictions are written; any
    surplus entries are dropped (presumably padding from the last batch --
    TODO confirm against the data loader).

    :param predictions: sliceable sequence or array of predictions, ordered
        like ``filenames``
    :param filenames: paths of the images the predictions belong to
    :param n: identifier interpolated into the output file name
    """
    names = [os.path.splitext(os.path.basename(image))[0] for image in filenames]
    out_dir = os.path.join("data", "submissions")
    # DataFrame.to_csv does not create missing directories; make sure the
    # target exists so a finished prediction run is not lost at the very end.
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)
    dfr = pd.DataFrame(predictions[:len(filenames)], index=names)
    dfr.to_csv(os.path.join(out_dir, "submission_%s.csv" % n))
if __name__ == "__main__":
    # Script entry point: load a pickled network, run it over the test images
    # chunk by chunk and write the result as a submission CSV.
    parser = argparse.ArgumentParser()
    parser.add_argument("-n",
                        "--network",
                        type=str,
                        default=os.path.join("data", "tidy", "net.pickle"),
                        help="Path to the pickled network file")
    parser.add_argument('--proba', dest='proba', action='store_const',
                        const=1, default=0,
                        help='flag to predict probabilities rather than hard targets')
    args = parser.parse_args()
    netfile = args.network

    print("Loading saved network...")
    if not os.path.isfile(netfile):
        print("No such file: %s" % netfile)
        # Non-zero status so calling scripts can detect the failure.
        sys.exit(1)
    try:
        network, output = load_network(netfile)
    except Exception as e:
        print("Could not load network: %s" % e)
        # Without a network `output` is undefined; stop here instead of
        # failing later with an unrelated NameError.
        sys.exit(1)

    print("Loading test dataset...")
    # load test data chunk
    dl = DataLoader(image_size=IMAGE_SIZE,
                    normalize=True,
                    # keep in sync with the progress estimate below
                    batch_size=BATCH_SIZE,
                    parallel=False,
                    train_path="train/trimmed",
                    test_path=os.path.join("test", "trimmed"))
    test_filenames = dl.test_images
    n_predictions = len(test_filenames)

    print("Compiling theano functions...")
    # set up symbolic variables
    X = T.tensor4('X')
    X_batch = T.tensor4('X_batch')
    # Deterministic forward pass; thresholded at 0.5 unless probabilities
    # were requested on the command line.
    net_output = output.get_output(X_batch, deterministic=True)
    if not args.proba:
        net_output = T.gt(net_output, 0.5)
    # NOTE(review): X does not appear in the expression built here, so the
    # `givens` substitution looks redundant -- kept unchanged.
    # NOTE(review): theano.Param is deprecated in newer Theano releases in
    # favour of theano.In -- confirm the pinned version before changing it.
    predict = theano.function(
        [theano.Param(X_batch)],
        net_output,
        givens={
            X: X_batch
        },
    )

    print("Predicting...")
    predictions = []
    for test_chunk in dl.test_gen():
        predictions.append(predict(test_chunk))
        # Clamp so the final batch cannot push the display past 100 %.
        processed = min(len(predictions) * BATCH_SIZE, n_predictions)
        sys.stdout.write("progress: %d %%\r" % (processed * 100. / max(n_predictions, 1)))
        sys.stdout.flush()

    if not predictions:
        # np.vstack raises ValueError on an empty list; report the real cause.
        print("No test data to predict on")
        sys.exit(1)

    print("Saving predictions")
    predictions = np.vstack(predictions)
    if not args.proba:
        predictions = get_predictions(predictions)
        save_submission(predictions.flatten(), test_filenames, n=23)
    else:
        save_submission(predictions, test_filenames, n='19')