# train_model.py
import os

import cv2
import joblib
import numpy as np
import pandas as pd
from sklearn import svm
from sklearn.metrics import classification_report
from sklearn.model_selection import GridSearchCV, train_test_split

# RESOURCE PATHS
RESOURCE_DIR = r"./resources"  # Root path for all static resources used.
RESOURCE_LABEL_CSV_PATH = "labels.csv"  # CSV filepath relative to the resource dir.

# ARTIFACT PATHS
ARTIFACT_IMG_DATA_SAVE_PATH = "imgdata.sav"  # Cache of the loaded image data DataFrame.
ARTIFACT_SVM_MODEL_SAVE = "svm.joblib"
ARTIFACT_SVM_TEST_DATA_SAVE = "svm_test_data.joblib"

# GLOBAL PARAMETERS
SEED = 3
CROSS_VALIDATION_FOLDS = 5

# SMOTE (currently unused: the balancing step in main() is disabled).
MAX_SAMPLES = 2000
SMOTE_RATIO = 1  # Num_minority / Num_majority after sampling.


def _getLabels(label_csv_path):
    """Retrieves class ID to label mappings in a dict.

    NOTE: A dict is returned instead of a list because class IDs could be
    arbitrary (e.g. non-incremental integers).

    Args:
        label_csv_path (str): Path to the label CSV file.

    Raises:
        ValueError: If the label CSV does not contain "Name" and "ClassId" columns.

    Returns:
        dict: Mapping from class ID to label name.
    """
    label_data = pd.read_csv(label_csv_path)
    if "Name" not in label_data.columns or "ClassId" not in label_data.columns:
        raise ValueError("Label CSV must contain the 'Name' and 'ClassId' columns!")
    # Look up columns by name rather than positionally unpacking itertuples(),
    # so the CSV column order does not matter.
    return {int(cid): name for cid, name in zip(label_data["ClassId"], label_data["Name"])}
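
# Expected labels.csv layout (an assumption inferred from the column checks in
# _getLabels; the class names shown here are hypothetical):
#
#     ClassId,Name
#     0,Speed limit (20km/h)
#     1,Speed limit (30km/h)
#     ...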


def bgr_to_cmyk(bgr):
    """Converts a BGR uint8 image into a CMYK uint8 image.

    Adapted from stackoverflow.com/questions/69955216
    """
    bgrdash = bgr.astype(np.float64) / 255.0  # np.float is removed in modern NumPy.
    K = 1 - np.max(bgrdash, axis=2)
    # Guard against division by zero on pure-black pixels (where K == 1).
    denom = np.where(K < 1, 1 - K, 1)
    C = (1 - bgrdash[..., 2] - K) / denom
    M = (1 - bgrdash[..., 1] - K) / denom
    Y = (1 - bgrdash[..., 0] - K) / denom
    return (np.dstack((C, M, Y, K)) * 255).astype(np.uint8)
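
# Minimal sanity check for bgr_to_cmyk (illustrative only, not part of the
# training pipeline): pure white should map to CMYK (0, 0, 0, 0) and pure
# black to (0, 0, 0, 255).
#
#     white = np.full((1, 1, 3), 255, dtype=np.uint8)
#     black = np.zeros((1, 1, 3), dtype=np.uint8)
#     assert (bgr_to_cmyk(white)[0, 0] == [0, 0, 0, 0]).all()
#     assert (bgr_to_cmyk(black)[0, 0] == [0, 0, 0, 255]).all()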


def minmax_scale(arr):
    """Scales an array linearly into the [0, 1] range."""
    min_value = np.min(arr)
    max_value = np.max(arr)
    if max_value == min_value:
        # Constant array: avoid division by zero.
        return np.zeros_like(arr, dtype=np.float64)
    return (arr - min_value) / (max_value - min_value)
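
# For example (illustrative): minmax_scale(np.array([0, 5, 10])) returns
# array([0. , 0.5, 1. ]).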


def loadImageData(labelMap, dataDir):
    """Loads image data from the resource directory.

    Specifically, we target {dataDir}/{CLASS_ID} for each class ID in labelMap.

    Args:
        labelMap (dict): Mapping from class ID to label name.
        dataDir (str): Directory containing one subdirectory per class ID.

    Returns:
        DataFrame: The flattened image features plus a "Target" column holding
            the class IDs.
    """
    features_flattened = []
    target = []
    # Load image data.
    for cls_id in labelMap.keys():
        print(f"Loading class: {cls_id} - '{labelMap[cls_id]}'")
        # Build our features array: each row is the min-max-scaled, flattened
        # CMYK representation of one raw image.
        class_data_path = os.path.join(dataDir, str(cls_id))
        for img in os.listdir(class_data_path):
            with open(os.path.join(class_data_path, img), "rb") as f:
                img_array_cmyk = bgr_to_cmyk(
                    cv2.imdecode(np.frombuffer(f.read(), dtype=np.uint8), cv2.IMREAD_UNCHANGED)
                )
            features = minmax_scale(img_array_cmyk).flatten()  # Feature vector
            features_flattened.append(features)
            target.append(cls_id)
        print(f"Loaded {labelMap[cls_id]} successfully")
    df = pd.DataFrame(features_flattened)
    df["Target"] = target
    print("Dataset loaded!")
    return df
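
# Expected on-disk layout for the training data (inferred from loadImageData
# and the paths used in main(); the image filenames are hypothetical):
#
#     resources/
#         labels.csv
#         myData/
#             0/
#                 00000.png
#                 ...
#             1/
#                 ...
#         myData-test/   (same structure; used for the final evaluation)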


def getImageData(reset_cache, label_map):
    """Retrieves a DataFrame with image data, using a cached copy if available.

    Args:
        reset_cache (bool): Whether to force a reload from disk.
        label_map (dict): Mapping from class ID to label name.

    Returns:
        DataFrame: Image data.
    """
    # Use the cached DataFrame if there is one already saved.
    if not reset_cache and os.path.isfile(ARTIFACT_IMG_DATA_SAVE_PATH):
        print(f"Using cached data at {ARTIFACT_IMG_DATA_SAVE_PATH}")
        df = joblib.load(ARTIFACT_IMG_DATA_SAVE_PATH)
    else:
        # No cached image data found, we must load it.
        try:
            print("Loading image data...")
            df = loadImageData(label_map, os.path.join(RESOURCE_DIR, "myData"))
            joblib.dump(df, ARTIFACT_IMG_DATA_SAVE_PATH)
        except Exception:
            # Clean up the save file if we couldn't finish.
            if os.path.isfile(ARTIFACT_IMG_DATA_SAVE_PATH):
                os.remove(ARTIFACT_IMG_DATA_SAVE_PATH)
            raise
    return df


def save_test_data(filename, x_test, y_test):
    """Saves the given test data to disk for later replication.

    Args:
        filename (str): Filename to save as.
        x_test (DataFrame): Features DataFrame.
        y_test (Series): Target labels.
    """
    test_data = x_test.copy()
    test_data["Target"] = y_test
    joblib.dump(test_data, filename)
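
# To reload the exported artifacts later (an illustrative sketch):
#
#     svm_clf = joblib.load(ARTIFACT_SVM_MODEL_SAVE)
#     test_data = joblib.load(ARTIFACT_SVM_TEST_DATA_SAVE)
#     x_test = test_data.drop(columns="Target")
#     y_test = test_data["Target"]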


def main():
    """Program entry method.

    Returns:
        int: 0 for success, failure code otherwise.
    """
    print(f"SEED: {SEED}")

    # [SETUP: 1/3] Fetch labels and data.
    print("Fetching labels...")
    labelMap = _getLabels(os.path.join(RESOURCE_DIR, RESOURCE_LABEL_CSV_PATH))
    print("Fetching data...")
    df = loadImageData(labelMap, os.path.join(RESOURCE_DIR, "myData"))

    # [SETUP: 2/3] Build the feature matrix and target vector. SMOTE synthetic
    # data generation is currently disabled; we work with the raw image data.
    # features, target = get_balanced_data(df, labelMap, MAX_SAMPLES, SMOTE_RATIO)
    features = df.drop(columns="Target")
    target = df["Target"]
    print(features.shape)
    print(target.shape)

    # Uncomment to run with a smaller, stratified subset of the dataset:
    # _, x_in, _, y_in = train_test_split(
    #     features, target, test_size=0.1, shuffle=True, random_state=SEED, stratify=target,
    # )
    # x_train, x_test, y_train, y_test = train_test_split(
    #     x_in, y_in, test_size=0.3, shuffle=True, random_state=SEED, stratify=y_in,
    # )

    # [SETUP: 3/3] Proper train/test split.
    x_train, x_test, y_train, y_test = train_test_split(
        features, target, test_size=0.3, shuffle=True, random_state=SEED, stratify=target,
    )
    print(f"{len(x_train)} | {len(x_test)} | {len(y_train)} | {len(y_test)}")  # Sanity check on split sizes.
print("Fitting SVM")
param_grid = {
"C": np.linspace(0.0001, 15, 15),
"gamma": np.linspace(0.0001, 15, 15),
"kernel": ["poly", "linear"],
}
param_cv = GridSearchCV(
svm.SVC(random_state=SEED),
param_grid,
cv=CROSS_VALIDATION_FOLDS,
verbose=3,
scoring="f1_macro",
n_jobs=os.cpu_count()
)
param_cv.fit(x_train, y_train)
print("Best parameters:")
print(param_cv.best_params_)
    # Train the full model with the best parameters found above. We refit
    # manually rather than reusing best_estimator_ so that probability
    # estimates can be enabled.
    print("Training full model...")
    svm_clf = svm.SVC(
        kernel=param_cv.best_params_["kernel"],
        gamma=param_cv.best_params_["gamma"],
        C=param_cv.best_params_["C"],
        random_state=SEED,
        probability=True,
    )
    svm_clf.fit(x_train, y_train)
    y_pred = svm_clf.predict(x_test)
    print(classification_report(y_test, y_pred))

    # [SVM: 2/2] Export the model and test data for later replication.
    joblib.dump(svm_clf, ARTIFACT_SVM_MODEL_SAVE)
    save_test_data(ARTIFACT_SVM_TEST_DATA_SAVE, x_test, y_test)
    print("\a")  # Ring the terminal bell to signal completion.

    # Evaluate on the separate held-out test directory.
    print("TESTING")
    df = loadImageData(labelMap, os.path.join(RESOURCE_DIR, "myData-test"))
    features = df.drop(columns="Target")
    target = df["Target"]
    pred = svm_clf.predict(features)
    print(classification_report(target, pred))
    return 0


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())
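
# Usage (illustrative, assuming the resources/ layout sketched above):
#
#     $ python train_model.py
#
# On success, svm.joblib and svm_test_data.joblib are written to the working
# directory.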