-
Notifications
You must be signed in to change notification settings - Fork 33
/
run_trained_model.sh
executable file
·36 lines (33 loc) · 1.22 KB
/
run_trained_model.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# Usage: run_trained_model.sh [hs]
#
# Chooses the dataset, the pre-trained model file and the extra code_gen.py
# options for the run. Passing "hs" as the first argument selects the hs
# configuration; any other value (or no argument at all) selects django.

# Directory handed to code_gen.py as -output_dir; the decode results and the
# evaluation log are written under it as well.
output="runs"
# Device name substituted into THEANO_FLAGS for the decoding step.
device="cpu"

# `case` is used instead of `[ ... == ... ]`: `==` inside `[` is a bash
# extension, so this form also selects correctly when the script is
# interpreted by a POSIX sh such as dash.
case "$1" in
  hs)
    # hs dataset
    echo "run trained model for hs"
    dataset="data/hs.freq3.pre_suf.unary_closure.bin"
    model="model.hs_unary_closure_top20_word128_encoder256_rule128_node64.beam15.adadelta.simple_trans.8e39832.iter5600.npz"
    commandline="-decode_max_time_step 350 -rule_embed_dim 128 -node_embed_dim 64"
    datatype="hs"
    ;;
  *)
    # django dataset (default)
    echo "run trained model for django"
    dataset="data/django.cleaned.dataset.freq5.par_info.refact.space_only.bin"
    model="model.django_word128_encoder256_rule128_node64.beam15.adam.simple_trans.no_unary_closure.8e39832.run3.best_acc.npz"
    commandline="-rule_embed_dim 128 -node_embed_dim 64"
    datatype="django"
    ;;
esac
# Decode the test set and save the nbest decoding results.
# THEANO_FLAGS is set for this one command only.
# ${commandline} is intentionally left unquoted: it holds several
# space-separated options that must be split into separate arguments.
# If decoding fails, stop here with the decoder's exit status rather than
# going on to evaluate a missing or stale results file.
# shellcheck disable=SC2086
THEANO_FLAGS="mode=FAST_RUN,device=${device},floatX=float32" python code_gen.py \
  -data_type "${datatype}" \
  -data "${dataset}" \
  -output_dir "${output}" \
  -model "models/${model}" \
  ${commandline} \
  decode \
  -saveto "${output}/${model}.decode_results.test.bin" || {
  status=$?
  echo "decoding failed (exit status ${status}); skipping evaluation" >&2
  exit "${status}"
}
# Evaluate the decoding result. code_gen.py's standard output is shown on the
# terminal and, via tee, also written to a .log file next to the results file.
# All expansions are quoted so the paths are passed as single arguments.
python code_gen.py \
  -data_type "${datatype}" \
  -data "${dataset}" \
  -output_dir "${output}" \
  evaluate \
  -input "${output}/${model}.decode_results.test.bin" \
  | tee "${output}/${model}.decode_results.test.log"