Stages of training ARIMA: initially just on Date and Close; combining with LSTM; updating the data to include day-of-week, lags, and other features and including them while training. Current RMSE: 4.67
Commit 40a39b4 · 1 parent 0f79d65 · Showing 3 changed files with 1,440 additions and 0 deletions.
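As context for the stages listed in the commit message, here is a condensed sketch of the final-stage pipeline (calendar features, lags, rolling means, and volume as exogenous regressors to SARIMAX). File paths, column names, the exogenous feature list, and the (2, 1, 2) × (1, 1, 1, 12) order are taken from the notebook below; the `add_features` helper is only an illustrative wrapper, and the notebook itself additionally holds out a validation split and pickles the fitted model and scaler.

```python
# Condensed sketch of the SARIMAX-with-exogenous-features pipeline in this commit.
# Paths, features, and the model order mirror the notebook; `add_features` is an
# illustrative helper, not a function defined in the notebook.
import pandas as pd
import numpy as np
from sklearn.metrics import mean_squared_error
from sklearn.preprocessing import StandardScaler
from statsmodels.tsa.statespace.sarimax import SARIMAX

def add_features(df):
    # Calendar, lag, rolling-mean, and volume features used as exogenous regressors.
    df = df.copy()
    df['day_of_week'] = df.index.dayofweek
    df['month'] = df.index.month
    df['lagged_close'] = df['Close'].shift(1)
    df['moving_avg_3'] = df['Close'].rolling(window=3).mean()
    df['moving_avg_7'] = df['Close'].rolling(window=7).mean()
    df['volume'] = df['Volume']
    return df.dropna()

exog = ['day_of_week', 'month', 'lagged_close', 'moving_avg_3', 'moving_avg_7', 'volume']

train = add_features(pd.read_csv('../Data/SBI Train data.csv',
                                 parse_dates=['Date'], dayfirst=True).set_index('Date'))
test = add_features(pd.read_csv('../Data/SBI Test data.csv',
                                parse_dates=['Date'], dayfirst=True).set_index('Date'))

# Fit the scaler on training data only and reuse it on the test data.
scaler = StandardScaler()
train[exog] = scaler.fit_transform(train[exog])
test[exog] = scaler.transform(test[exog])

fit = SARIMAX(train['Close'], exog=train[exog],
              order=(2, 1, 2), seasonal_order=(1, 1, 1, 12)).fit(disp=False)
forecast = fit.forecast(steps=len(test), exog=test[exog])
print('Test RMSE:', np.sqrt(mean_squared_error(test['Close'], forecast)))
```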
@@ -0,0 +1,337 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Intial code\n", | ||
"for reference purposes" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# import pandas as pd\n", | ||
"# import numpy as np\n", | ||
"# from sklearn.metrics import mean_squared_error\n", | ||
"# from sklearn.preprocessing import StandardScaler\n", | ||
"# from statsmodels.tsa.statespace.sarimax import SARIMAX\n", | ||
"# from statsmodels.tools.sm_exceptions import ConvergenceWarning\n", | ||
"# import warnings\n", | ||
"\n", | ||
"# # Ignore convergence warnings\n", | ||
"# warnings.simplefilter(\"ignore\")\n", | ||
"\n", | ||
"# # Load dataset with parsed dates\n", | ||
"# data = pd.read_csv('../Data/SBI Train data.csv', parse_dates=['Date'], dayfirst=True)\n", | ||
"\n", | ||
"# # Set the index to the Date column\n", | ||
"# data.set_index('Date', inplace=True)\n", | ||
"# # data = data.asfreq('D')\n", | ||
"# # Feature Engineering: Add day of week and month\n", | ||
"# data['day_of_week'] = data.index.dayofweek\n", | ||
"# data['month'] = data.index.month\n", | ||
"\n", | ||
"# # Add lagged value of the Close price and moving averages\n", | ||
"# data['lagged_close'] = data['Close'].shift(1) \n", | ||
"# data['moving_avg_3'] = data['Close'].rolling(window=3).mean()\n", | ||
"# data['moving_avg_7'] = data['Close'].rolling(window=7).mean() # New: 7-day moving average for long-term trend\n", | ||
"\n", | ||
"# # Add Volume as a feature (scaling might help)\n", | ||
"# data['volume'] = data['Volume']\n", | ||
"\n", | ||
"# # Drop rows with NaN values\n", | ||
"# data.dropna(inplace=True)\n", | ||
"\n", | ||
"# # Standardize the features (important for scaling)\n", | ||
"# scaler = StandardScaler()\n", | ||
"# exog_features = ['day_of_week', 'month', 'lagged_close', 'moving_avg_3', 'moving_avg_7', 'volume']\n", | ||
"# data[exog_features] = scaler.fit_transform(data[exog_features])\n", | ||
"\n", | ||
"# # Split the data into training and testing sets\n", | ||
"# train_size = int(len(data) * 0.8)\n", | ||
"# train, test = data.iloc[:train_size], data.iloc[train_size:]\n", | ||
"\n", | ||
"# # Tune SARIMAX hyperparameters (ARIMA order (p, d, q))\n", | ||
"# order = (2, 1, 2) # Consider using AIC/BIC for finding optimal order\n", | ||
"# seasonal_order = (1, 1, 1, 12) # Adding seasonality with monthly frequency\n", | ||
"\n", | ||
"# # Fit the SARIMAX model\n", | ||
"# try:\n", | ||
"# model = SARIMAX(train['Close'], \n", | ||
"# exog=train[exog_features],\n", | ||
"# order=order,\n", | ||
"# seasonal_order=seasonal_order)\n", | ||
"# model_fit = model.fit(disp=False)\n", | ||
"# except ConvergenceWarning as e:\n", | ||
"# print(f\"Convergence warning: {e}\")\n", | ||
"# except Exception as e:\n", | ||
"# print(f\"Error: {e}\")\n", | ||
"\n", | ||
"# # Forecasting\n", | ||
"# forecast = model_fit.forecast(steps=len(test), exog=test[exog_features])\n", | ||
"\n", | ||
"# # Calculate RMSE for forecast\n", | ||
"# rmse_arimax = np.sqrt(mean_squared_error(test['Close'], forecast))\n", | ||
"# print(f\"Improved ARIMAX Model RMSE: {rmse_arimax}\")\n", | ||
"\n", | ||
"# test_prices = [i for i in test['Close']]\n", | ||
"# # Check residuals diagnostics (optional)\n", | ||
"# residuals = test_prices - forecast\n", | ||
"# print(\"Mean of residuals:\", residuals.mean())\n", | ||
"# print(\"Standard deviation of residuals:\", residuals.std())\n", | ||
"\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### immporting necessary libraries" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import pandas as pd\n", | ||
"import numpy as np\n", | ||
"import pickle\n", | ||
"from sklearn.metrics import mean_squared_error\n", | ||
"from sklearn.preprocessing import StandardScaler\n", | ||
"from statsmodels.tsa.statespace.sarimax import SARIMAX\n", | ||
"from statsmodels.tools.sm_exceptions import ConvergenceWarning\n", | ||
"import warnings" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Udating features to dataset for proper time-series analysis" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"\n", | ||
"# Ignore convergence warnings\n", | ||
"warnings.simplefilter(\"ignore\", ConvergenceWarning)\n", | ||
"\n", | ||
"# Load training dataset with parsed dates\n", | ||
"train_data = pd.read_csv('../Data/SBI Train data.csv', parse_dates=['Date'], dayfirst=True)\n", | ||
"\n", | ||
"# Set the index to the Date column\n", | ||
"train_data.set_index('Date', inplace=True)\n", | ||
"\n", | ||
"# Feature Engineering: Add day of week and month\n", | ||
"train_data['day_of_week'] = train_data.index.dayofweek\n", | ||
"train_data['month'] = train_data.index.month\n", | ||
"\n", | ||
"# Add lagged value of the Close price and moving averages\n", | ||
"train_data['lagged_close'] = train_data['Close'].shift(1)\n", | ||
"train_data['moving_avg_3'] = train_data['Close'].rolling(window=3).mean()\n", | ||
"train_data['moving_avg_7'] = train_data['Close'].rolling(window=7).mean()\n", | ||
"\n", | ||
"# Add Volume as a feature (scaling might help)\n", | ||
"train_data['volume'] = train_data['Volume']\n", | ||
"\n", | ||
"# Drop rows with NaN values after applying the rolling window and lagging\n", | ||
"train_data.dropna(inplace=True)\n", | ||
"\n", | ||
"# Standardize the features\n", | ||
"scaler = StandardScaler()\n", | ||
"exog_features = ['day_of_week', 'month', 'lagged_close', 'moving_avg_3', 'moving_avg_7', 'volume']\n", | ||
"train_data[exog_features] = scaler.fit_transform(train_data[exog_features])\n", | ||
"\n", | ||
"# Split the data into training and testing sets\n", | ||
"train_size = int(len(train_data) * 0.8)\n", | ||
"train, validation = train_data.iloc[:train_size], train_data.iloc[train_size:]\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Training and savinng model" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"C:\\Users\\agraw\\AppData\\Roaming\\Python\\Python311\\site-packages\\statsmodels\\tsa\\base\\tsa_model.py:473: ValueWarning: A date index has been provided, but it has no associated frequency information and so will be ignored when e.g. forecasting.\n", | ||
" self._init_dates(dates, freq)\n", | ||
"C:\\Users\\agraw\\AppData\\Roaming\\Python\\Python311\\site-packages\\statsmodels\\tsa\\base\\tsa_model.py:473: ValueWarning: A date index has been provided, but it has no associated frequency information and so will be ignored when e.g. forecasting.\n", | ||
" self._init_dates(dates, freq)\n", | ||
"C:\\Users\\agraw\\AppData\\Roaming\\Python\\Python311\\site-packages\\statsmodels\\tsa\\statespace\\sarimax.py:978: UserWarning: Non-invertible starting MA parameters found. Using zeros as starting parameters.\n", | ||
" warn('Non-invertible starting MA parameters found.'\n" | ||
] | ||
}, | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Model and scaler saved successfully.\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# Train the SARIMAX model\n", | ||
"order = (2, 1, 2)\n", | ||
"seasonal_order = (1, 1, 1, 12)\n", | ||
"\n", | ||
"model = SARIMAX(train['Close'], exog=train[exog_features], order=order, seasonal_order=seasonal_order)\n", | ||
"model_fit = model.fit(disp=False)\n", | ||
"\n", | ||
"# Save the model to a file using pickle\n", | ||
"with open('sarimax_model.pkl', 'wb') as f:\n", | ||
" pickle.dump(model_fit, f)\n", | ||
"\n", | ||
"# Optionally save the scaler as well\n", | ||
"with open('scaler.pkl', 'wb') as f:\n", | ||
" pickle.dump(scaler, f)\n", | ||
"\n", | ||
"print(\"Model and scaler saved successfully.\")\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Loading saved model" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Load the model and scaler from the files\n", | ||
"with open('sarimax_model.pkl', 'rb') as f:\n", | ||
" loaded_model = pickle.load(f)\n", | ||
"\n", | ||
"with open('scaler.pkl', 'rb') as f:\n", | ||
" loaded_scaler = pickle.load(f)\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Loading and processing Test data" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Load the test dataset\n", | ||
"test_data = pd.read_csv('../Data/SBI Test data.csv', parse_dates=['Date'], dayfirst=True)\n", | ||
"\n", | ||
"# Set the index to the Date column\n", | ||
"test_data.set_index('Date', inplace=True)\n", | ||
"\n", | ||
"# Apply the same feature engineering on the test data\n", | ||
"test_data['day_of_week'] = test_data.index.dayofweek\n", | ||
"test_data['month'] = test_data.index.month\n", | ||
"test_data['lagged_close'] = test_data['Close'].shift(1)\n", | ||
"test_data['moving_avg_3'] = test_data['Close'].rolling(window=3).mean()\n", | ||
"test_data['moving_avg_7'] = test_data['Close'].rolling(window=7).mean()\n", | ||
"\n", | ||
"# Add Volume as a feature\n", | ||
"test_data['volume'] = test_data['Volume']\n", | ||
"\n", | ||
"# Drop rows with NaN values\n", | ||
"test_data.dropna(inplace=True)\n", | ||
"\n", | ||
"# Standardize the features in the test dataset using the loaded scaler\n", | ||
"test_data[exog_features] = loaded_scaler.transform(test_data[exog_features])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Predicting share prices using model" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"C:\\Users\\agraw\\AppData\\Roaming\\Python\\Python311\\site-packages\\statsmodels\\tsa\\base\\tsa_model.py:837: ValueWarning: No supported index is available. Prediction results will be given with an integer index beginning at `start`.\n", | ||
" return get_prediction_index(\n", | ||
"C:\\Users\\agraw\\AppData\\Roaming\\Python\\Python311\\site-packages\\statsmodels\\tsa\\base\\tsa_model.py:837: FutureWarning: No supported index is available. In the next version, calling this method in a model without a supported index will result in an exception.\n", | ||
" return get_prediction_index(\n" | ||
] | ||
}, | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Test Data RMSE: 4.673693537736142\n", | ||
"Mean of residuals: 0.311316834051805\n", | ||
"Standard deviation of residuals: 4.664969245987366\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# Forecasting on the test data using the loaded model\n", | ||
"forecast_test = loaded_model.forecast(steps=len(test_data), exog=test_data[exog_features])\n", | ||
"\n", | ||
"# Calculate RMSE for forecast\n", | ||
"rmse_test = np.sqrt(mean_squared_error(test_data['Close'], forecast_test))\n", | ||
"print(f\"Test Data RMSE: {rmse_test}\")\n", | ||
"\n", | ||
"# Check residuals diagnostics (optional)\n", | ||
"test_prices = test_data['Close'].values\n", | ||
"residuals_test = test_prices - forecast_test\n", | ||
"print(\"Mean of residuals:\", residuals_test.mean())\n", | ||
"print(\"Standard deviation of residuals:\", residuals_test.std())" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.4" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
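The commented reference cell in the notebook suggests using AIC/BIC to choose the (p, d, q) order instead of the hard-coded (2, 1, 2). A minimal sketch of that idea, assuming `train` and `exog_features` are already prepared exactly as in the notebook and using only the statsmodels API the notebook imports:

```python
# Minimal AIC-based order search (sketch): assumes `train` and `exog_features`
# exist exactly as defined in the notebook above.
import itertools
from statsmodels.tsa.statespace.sarimax import SARIMAX

best_aic, best_order = float('inf'), None
for p, d, q in itertools.product(range(3), [1], range(3)):
    try:
        fit = SARIMAX(train['Close'], exog=train[exog_features],
                      order=(p, d, q),
                      seasonal_order=(1, 1, 1, 12)).fit(disp=False)
    except Exception:
        continue  # skip orders that fail to estimate
    if fit.aic < best_aic:
        best_aic, best_order = fit.aic, (p, d, q)

print(f"Best order by AIC: {best_order} (AIC={best_aic:.2f})")
```

The grid is kept deliberately small because each SARIMAX fit with a 12-period seasonal component is relatively slow; widening the p/q range or varying d and the seasonal terms follows the same pattern.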