diff --git a/app/services/gaze_tracker.py b/app/services/gaze_tracker.py index 7d1f7ce..2ea0eb4 100644 --- a/app/services/gaze_tracker.py +++ b/app/services/gaze_tracker.py @@ -110,8 +110,7 @@ def predict(data, k, model_X, model_Y): if ( model_X == "Linear Regression" - or model_X == "Elastic Net" - or model_X == "Support Vector Regressor" + or model_X == 'Random Forest Regressor' ): model = models[model_X] @@ -152,11 +151,10 @@ def predict(data, k, model_X, model_Y): if ( model_Y == "Linear Regression" - or model_Y == "Elastic Net" - or model_Y == "Support Vector Regressor" + or model_Y=="Random Forest Regressor" ): model = models[model_Y] - + # Fit the model and make predictions model.fit(X_train_y, y_train_y) y_pred_y = model.predict(X_test_y)