From c68c2f1dabb973b70565cfd8b2c42dc4cd3173ec Mon Sep 17 00:00:00 2001 From: Junzhuo Du Date: Wed, 2 Jan 2019 00:07:42 -0800 Subject: [PATCH] Update train_svm.py --- Exercise-3/sensor_stick/scripts/train_svm.py | 31 ++++++++++---------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/Exercise-3/sensor_stick/scripts/train_svm.py b/Exercise-3/sensor_stick/scripts/train_svm.py index b7af250..6e1f0b6 100755 --- a/Exercise-3/sensor_stick/scripts/train_svm.py +++ b/Exercise-3/sensor_stick/scripts/train_svm.py @@ -5,7 +5,7 @@ import matplotlib.pyplot as plt from sklearn import svm from sklearn.preprocessing import LabelEncoder, StandardScaler -from sklearn import cross_validation +from sklearn.model_selection import cross_val_score, cross_val_predict, KFold from sklearn import metrics def plot_confusion_matrix(cm, classes, @@ -65,27 +65,26 @@ def plot_confusion_matrix(cm, classes, clf = svm.SVC(kernel='linear') # Set up 5-fold cross-validation -kf = cross_validation.KFold(len(X_train), - n_folds=5, - shuffle=True, - random_state=1) +kf = KFold(n_splits=5, + shuffle=True, + random_state=1) # Perform cross-validation -scores = cross_validation.cross_val_score(cv=kf, - estimator=clf, - X=X_train, - y=y_train, - scoring='accuracy' - ) +scores = cross_val_score(estimator=clf, + X=X_train, + y=y_train, + scoring='accuracy', + cv=kf + ) print('Scores: ' + str(scores)) print('Accuracy: %0.2f (+/- %0.2f)' % (scores.mean(), 2*scores.std())) # Gather predictions -predictions = cross_validation.cross_val_predict(cv=kf, - estimator=clf, - X=X_train, - y=y_train - ) +predictions = cross_val_predict(estimator=clf, + X=X_train, + y=y_train, + cv=kf + ) accuracy_score = metrics.accuracy_score(y_train, predictions) print('accuracy score: '+str(accuracy_score))