-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_knn_models.py
70 lines (55 loc) · 2.02 KB
/
train_knn_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
"""
This script trains the KNN models on the main components.
It can run on the Raspberry Pi.
"""
# ---------------------------------------------
# 1. Initialization
# ---------------------------------------------
# Folder that holds the input CSV; the trained models are saved here too.
model_folder = "./models/knn"
# CSV file with the component data, located inside the model folder.
components_path = model_folder + "/components.csv"
# Progress message emitted right before the package imports.
print("Import Packages...")
from sklearn.neighbors import KNeighborsRegressor
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
import joblib
# ---------------------------------------------
# 2. Prepare data
# ---------------------------------------------
def _show_shapes(title, features, targets):
    """Print a title line followed by the shapes of both arrays."""
    print(title)
    print("X:", features.shape, " Y:", targets.shape)


print("\nReading Data...")
components_pd = pd.read_csv(components_path)
# Column 0 is the single input feature, kept 2-D as (n_samples, 1);
# columns 1-2 are the two regression targets.
X = components_pd.iloc[:, [0]].to_numpy()
Y = components_pd.iloc[:, 1:3].to_numpy()
_show_shapes("Original dataset shape", X, Y)
# Hold out 10% of the rows for testing; the fixed seed keeps the split
# reproducible between runs.
X_train, X_test, Y_train, Y_test = train_test_split(
    X,
    Y,
    test_size=0.1,
    random_state=42,
)
_show_shapes("Train dataset shape", X_train, Y_train)
_show_shapes("Test dataset shape", X_test, Y_test)
# ---------------------------------------------
# 3. Prepare model
# ---------------------------------------------
# Requested neighbourhood size, shared by both regressors.
neighbor_num = 30
# scikit-learn accepts n_neighbors > n_samples at fit time but raises a
# ValueError as soon as the model is queried (score/predict). Cap the value
# at the training-set size so a small dataset does not crash the script
# after training has already finished.
if len(X_train) < neighbor_num:
    print("Only", len(X_train), "training samples; reducing n_neighbors from", neighbor_num)
    neighbor_num = len(X_train)
# One independent regressor per target column of Y.
neigh_x = KNeighborsRegressor(n_neighbors=neighbor_num)
neigh_y = KNeighborsRegressor(n_neighbors=neighbor_num)
# ---------------------------------------------
# 4. Fit and test models
# ---------------------------------------------
print("\nStart training...")
# Each regressor learns one target column: neigh_x -> column 0,
# neigh_y -> column 1.
for _column, _model in enumerate((neigh_x, neigh_y)):
    _model.fit(X_train, Y_train[:, _column])
print("------------------------------------")
print("Scores")
# Evaluate every regressor on the held-out rows of its own target column.
sc1, sc2 = [
    _model.score(X_test, Y_test[:, _column])
    for _column, _model in enumerate((neigh_x, neigh_y))
]
print(sc1, sc2)
# ---------------------------------------------
# 5. Save models
# ---------------------------------------------
# Both model files are written into the model folder.
model_path1 = "/".join((model_folder, "knn_x.joblib"))
model_path2 = "/".join((model_folder, "knn_y.joblib"))
# Announce both destinations first, then write the files.
for _path in (model_path1, model_path2):
    print("Saving in", _path)
for _model, _path in ((neigh_x, model_path1), (neigh_y, model_path2)):
    joblib.dump(_model, _path)