-
Notifications
You must be signed in to change notification settings - Fork 0
/
load.py
68 lines (52 loc) · 2.81 KB
/
load.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import h5py
from classification import predict, submit_predictions
from config import __DIRNAME_INPUT__, __FONTS_DICT__
from preprocess import process_db
from training import create_model, train_model
from utils import split_list_by_percentage
def load_training_and_validation(ds_filename='train.h5'):
    """Train a font model on a labelled database and run it on a held-out part.

    The database is opened read-only from ``__DIRNAME_INPUT__ + ds_filename``
    and passed to ``process_db`` with ``train_mode=True``. The returned word
    list is divided by ``split_list_by_percentage`` with 80 as the training
    percentage, and the per-character image and font lists are cut at the
    number of characters contained in the training words. A model is then
    created, trained, and ``predict`` is called on the validation part.

    Args:
        ds_filename: Database file name. It is appended directly to
            ``__DIRNAME_INPUT__``; no path separator is inserted here.

    Returns:
        None. Neither the trained model nor the validation predictions are
        returned by this function.
    """
    print("Stage: load training and validation database")
    db_path = __DIRNAME_INPUT__ + ds_filename

    # NOTE(review): the whole pipeline stays inside the `with` block so the
    # file is still open if process_db returns data backed by it — confirm
    # against process_db whether the file could be closed earlier.
    with h5py.File(db_path, 'r') as db:
        words, char_fonts, char_images = process_db(db, train_mode=True)

        print("Stage: split data to validation and training")
        training_percentage = 80  # share of the words used for training, in %
        words_training, words_validation = split_list_by_percentage(
            words, training_percentage)

        # Count every character of every training word. The per-character
        # lists are sliced at this index, which assumes they are ordered
        # word by word with the training words first — TODO confirm.
        training_char_count = sum(
            1 for word in words_training for _ in word)

        char_images_training = char_images[:training_char_count]
        char_images_validation = char_images[training_char_count:]
        char_fonts_training = char_fonts[:training_char_count]
        char_fonts_validation = char_fonts[training_char_count:]
        print("list_words_validation: ", len(words_validation))

        # One model output per entry of the fonts dictionary.
        untrained_model = create_model(len(__FONTS_DICT__))

        print("Stage: training")
        # The training result is bound but not used again in this function.
        fitted_model = train_model(
            untrained_model,
            char_images_training,
            char_fonts_training,
            char_images_validation,
            char_fonts_validation,
        )

        print("Stage: making predictions on validation data")
        # The prediction result is bound but not used again in this function.
        validation_prediction_ids = predict(
            words_validation, char_images_validation, char_fonts_validation)
def predict_on_test_data(classification_model_file_name, ds_filename='test.h5'):
    """Run font prediction on a test database and submit the result.

    The database is opened read-only from ``__DIRNAME_INPUT__ + ds_filename``
    and passed to ``process_db`` with ``train_mode=False``. The processed
    lists and the model file name are given to ``predict``, and whatever it
    returns is handed to ``submit_predictions`` (announced in the log as a
    .csv export).

    Args:
        classification_model_file_name: Forwarded unchanged to ``predict`` as
            its fourth positional argument.
        ds_filename: Database file name. It is appended directly to
            ``__DIRNAME_INPUT__``; no path separator is inserted here.

    Returns:
        None.
    """
    print("Stage: load test database")
    db_path = __DIRNAME_INPUT__ + ds_filename

    # NOTE(review): everything stays inside the `with` block so the file is
    # still open if process_db returns data backed by it — confirm.
    with h5py.File(db_path, 'r') as db:
        words, char_fonts, char_images = process_db(db, train_mode=False)

        print("Stage: making predictions on test data")
        prediction_ids = predict(
            words, char_images, char_fonts, classification_model_file_name)

        print("Stage: save fonts predictions as .csv file")
        submit_predictions(prediction_ids)