-
Notifications
You must be signed in to change notification settings - Fork 4
/
model_plot.py
87 lines (69 loc) · 1.97 KB
/
model_plot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import sys
import matplotlib.pyplot as plt
from src.parameters import market
from src.service.dataset_builder import DatasetBuilder
from src.service.predictor import Predictor
# Variables
# ------------------------------------------------------------------------
# Usage: python model_plot.py <interval> <model_path>
# Fail with a usage message instead of an IndexError when arguments are missing.
if len(sys.argv) < 3:
    sys.exit(f'usage: {sys.argv[0]} <interval> <model_path>')
interval = sys.argv[1]  # 5m, 15m, 30m ...
model_path = sys.argv[2]  # e.g. model/gru-g-50-5000-223-5m-BTC.keras
assets = [
    'BTC',
    'ETH',
    'BNB',
    'ADA',
]
# Number of trailing points drawn in each subplot.
tail = 50
# Passed as `width` to Predictor and build_dataset_predict; presumably the
# number of candles in the input window — confirm against those services.
width = 500
# Services
# ------------------------------------------------------------------------
# The predictor loads the model given on the command line (sys.argv[2]), so the
# plotted predictions come from the model the user asked for and match the
# requested interval.
predictor = Predictor(
    assets=assets,
    market=market,
    interval=interval,
    model_path=model_path,
    width=width,
)
# Builds the per-asset input frames consumed by the predictor.
dataset_builder = DatasetBuilder(
    assets,
    interval,
    market,
)
# Data load
# Predicting close price on the next time interval
# ------------------------------------------------------------------------
# Global plot styling. rcParams only affect axes created after they are set,
# so apply them once, before any figure or axes exists.
plt.rcParams['axes.facecolor'] = 'white'
plt.rcParams['axes.edgecolor'] = 'white'
plt.rcParams['axes.grid'] = True
plt.rcParams['grid.alpha'] = 1
plt.rcParams['grid.color'] = "#cccccc"

# The same predictor serves every asset, so load its model once rather than
# on every loop iteration.
predictor.load_model()

collection = dataset_builder.build_dataset_predict(width=width)
for x_df in collection:
    # Asset identifier, read from the 'asset' column of the last row.
    asset = x_df.iloc[-1]['asset']
    # Drop the most recent row before predicting. The full x_df is still
    # plotted as the "real" series, presumably so the held-back value can be
    # compared with the prediction — confirm.
    x_df_shifted = x_df[:-1]
    y_df = predictor.make_prediction_ohlc_close(x_df=x_df_shifted)

    # Plotting
    # --------------------------------------------------------------------
    # Create both subplots explicitly and style each one. Module-level
    # plt.xlim()/plt.grid() calls made before plt.subplot() act on a separate
    # implicit axes, not on either of the two plots.
    fig, (ax_real, ax_pred) = plt.subplots(2, 1, figsize=(16, 8))

    # NOTE(review): the real series is the 'open' column while the prediction
    # is 'close' — confirm this comparison is intended.
    ax_real.plot(
        x_df['open'].tail(tail).values,
        color='blue',
        label='real',
        marker='.',
    )
    ax_pred.plot(
        y_df['close'].tail(tail).values,
        color='green',
        label=f'predict {interval} {asset}',
        marker='.',
    )

    for ax in (ax_real, ax_pred):
        # Start the x-axis at the first plotted point; the right bound
        # follows the `tail` points actually drawn.
        ax.set_xlim(left=0)
        ax.grid(True)
        ax.grid(which='minor', alpha=0.2)
        ax.grid(which='major', alpha=0.5)
        ax.legend()

    plt.show()
    # Release the figure so non-interactive backends don't accumulate figures.
    plt.close(fig)