# forecasting.py
#
# Prophet cross-validation helper (produce_cv) and module-level setup for
# plotting style, warning filtering, and Prophet/cmdstanpy logging levels.
from exploration.utils import regressors, suppress_log
import os
from prophet import Prophet
import pandas as pd
import datetime
from datetime import date, datetime, timedelta
import matplotlib.pyplot as plt
import itertools
import logging
import json
import warnings
import tqdm
import numpy as np
from prophet.diagnostics import cross_validation
from prophet.diagnostics import performance_metrics
import seaborn as sns
import multiprocessing
# Module-level setup, executed once on import.

# Plot styling for seaborn/matplotlib figures drawn by this module.
sns.set(style="darkgrid")

# NOTE(review): this ignores *every* warning process-wide, not only Prophet's;
# consider scoping it with warnings.catch_warnings() around the noisy calls.
warnings.simplefilter(action='ignore')

# Quiet Prophet and its Stan backend. Prophet >= 1.0 (imported above as
# `prophet`) logs under "prophet"; "fbprophet" is the pre-1.0 package's logger
# name and is kept for compatibility. Silencing only "fbprophet" left Prophet's
# own INFO messages enabled.
for _logger_name in ("fbprophet", "prophet", "cmdstanpy"):
    logging.getLogger(_logger_name).setLevel(logging.ERROR)
del _logger_name
def produce_cv(df, params=None, regr=None, cutoffs=None, horizon=None):
    """Fit a Prophet model on ``df`` and score it with Prophet cross-validation.

    The model always includes UK country holidays. Before fitting, ``df`` is
    reduced to its ``ds``/``y`` columns and passed through
    ``regressors.produce_flags``. Every column after the first two of the
    resulting frame is registered as an extra regressor. All of this runs
    inside ``suppress_log.suppress_stdout_stderr()``.

    Args:
        df: Training data containing ``ds`` and ``y`` columns. Any other
            columns are dropped before ``produce_flags`` is applied.
        params: Optional mapping of keyword arguments forwarded to
            ``Prophet(...)``. ``None`` uses Prophet's defaults.
        regr: Optional iterable of column names (from the ``produce_flags``
            output) to keep as regressors. ``None`` keeps every column that
            ``produce_flags`` returns.
        cutoffs: Optional cutoffs forwarded unchanged to
            ``prophet.diagnostics.cross_validation``.
        horizon: Horizon string forwarded to ``cross_validation``. ``None``
            means ``'30 days'``.

    Returns:
        A two-element list ``[smape, df_cv]``:
            - ``smape``: the first value of the ``smape`` column returned by
              ``performance_metrics(df_cv, rolling_window=1)``.
            - ``df_cv``: the raw cross-validation frame, computed with
              ``parallel="processes"``.
    """
    with suppress_log.suppress_stdout_stderr():
        # Construct (not yet fit) the model. Prophet requires holidays and
        # regressors to be registered before fit() is called.
        m = Prophet(**params) if params is not None else Prophet()
        m.add_country_holidays(country_name='UK')

        df = regressors.produce_flags(df[['ds', 'y']])
        if regr is not None:
            # list(...) lets callers pass a tuple or any iterable, not only a list.
            df = df[['ds', 'y'] + list(regr)]

        # NOTE(review): assumes 'ds' and 'y' are the first two columns here;
        # every later column is treated as a regressor.
        for col in df.columns[2:]:
            m.add_regressor(col)
        m.fit(df)

        df_cv = cross_validation(
            m,
            cutoffs=cutoffs,
            horizon='30 days' if horizon is None else horizon,
            parallel="processes",
        )
        df_p = performance_metrics(df_cv, rolling_window=1)
        return [df_p['smape'].values[0], df_cv]