forked from megagonlabs/doduo
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcreate_pickle.py
105 lines (70 loc) · 3.35 KB
/
create_pickle.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import math
import os
import pickle
import numpy as np
import pandas as pd
# Input/output locations used by this pipeline.
EXTEND_PATH = "./extend_col_class_checked_fg.csv"
TOTAL_SCORE_PATH = "./total_score.txt"
SOURCE_PATH = './uploads/'
RESULT_PATH = './result/'

# Deserialize the upstream artifact; it must contain an 'mlb'
# (MultiLabelBinarizer) entry whose classes_ define the label vocabulary.
with open('data/table_col_type_serialized.pkl', 'rb') as f:
    data = pickle.load(f)
def read_tables(source):
    """Return the names of every directory entry directly inside *source*."""
    with os.scandir(source) as entries:
        return [entry.name for entry in entries]
# Label vocabulary comes from the fitted MultiLabelBinarizer.  The one-hot
# width is derived from it instead of the previous hard-coded 255 so the
# script keeps working if the class set changes size.
class_list = list(data['mlb'].classes_)
label_width = len(class_list)
# class_position = int(np.where(data['mlb'].classes_== 'location.location')[0][0])
files = read_tables(SOURCE_PATH)


def _column_label(column_name):
    """Extract the semantic type from a header shaped like 'name(type.sub)'."""
    return column_name.split('(')[1].strip(')')


def _table_rows(table, split_prefix):
    """Build one row dict per (data row, column) pair of *table*.

    split_prefix is "1"/"2"/"3" for train/dev/test, matching the original
    table_id scheme "<split>-<column index + 1>".
    """
    rows = []
    for values in table.values:
        for j, column in enumerate(table.columns):
            label = _column_label(column)
            # Fresh one-hot vector per row so rows never share a list object.
            label_ids = [0] * label_width
            label_ids[int(np.where(data['mlb'].classes_ == label)[0][0])] = 1
            rows.append({
                'table_id': split_prefix + "-" + str(j + 1),
                'labels': label,
                'data': values[j],
                'label_ids': label_ids,
            })
    return rows


# Split sizes: number of tables per dataframe — 60% train, remaining 40%
# divided evenly between dev and test (both rounded up).
count_train, count_dev, count_test = 0, 0, 0
average_count_train_table = math.ceil(len(files) * 0.6)
average_count_table = math.ceil((len(files) * 0.4) / 2)

# Accumulate plain dicts and build each DataFrame once at the end:
# - fixes the original bug where the split DataFrame was re-created for
#   every file, discarding all previously processed tables;
# - replaces DataFrame.append (removed in pandas 2.0) and its O(n^2) cost.
train_rows, dev_rows, test_rows = [], [], []
for name in files:
    table = pd.read_csv(SOURCE_PATH + name)
    if count_train < average_count_train_table:
        count_train += 1
        train_rows.extend(_table_rows(table, "1"))
    elif count_dev < average_count_table:
        count_dev += 1
        dev_rows.extend(_table_rows(table, "2"))
    elif count_test < average_count_table:
        count_test += 1
        test_rows.extend(_table_rows(table, "3"))

# Explicit columns keep the schema stable even for an empty split
# (the original raised NameError when a split received zero files).
columns = ['table_id', 'labels', 'data', 'label_ids']
train = pd.DataFrame(train_rows, columns=columns)
dev = pd.DataFrame(dev_rows, columns=columns)
test = pd.DataFrame(test_rows, columns=columns)

result = {
    'train': train,
    'dev': dev,
    'test': test,
    'mlb': data['mlb'],
}
with open('result.pkl', 'wb') as f:
    pickle.dump(result, f)