import os
import shutil
import time
from datetime import datetime
import config
import numpy as np
from stock_data import StockData
from agent import AgentPPO, AgentSAC, AgentTD3, AgentDDPG, AgentModSAC, AgentDuelingDQN, AgentSharedSAC, \
AgentDoubleDQN
from run_batch import Arguments, train_and_evaluate_mp, train_and_evaluate
from train_helper import init_model_hyper_parameters_table_sqlite, query_model_hyper_parameters_sqlite, \
update_model_hyper_parameters_by_train_history, clear_train_history_table_sqlite
from utils import date_time
from utils.date_time import get_datetime_from_date_str, get_next_work_day, get_today_date
from env_batch import StockTradingEnvBatch
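
# Batch training driver: in a loop, pick the hyper-parameter record with the
# fewest training runs from SQLite, build train/eval environments over the
# batch of A-share tickers, train the selected agent, then save the weights,
# the learning curve and the updated hyper-parameters.
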
if __name__ == '__main__':
    # Start time of this run
time_begin = datetime.now()
    # Initialize the hyper-parameter table (SQLite)
init_model_hyper_parameters_table_sqlite()
fe_table_name = 'fe_fillzero_train'
    # Order of the stock codes; do not change it
# config.BATCH_A_STOCK_CODE = ['sz.000028', 'sh.600585', 'sz.000538', 'sh.600036']
config.START_DATE = "2004-05-01"
config.BATCH_A_STOCK_CODE = StockData.get_batch_a_share_code_list_string(table_name='tic_list_275')
    # Initial cash: 150,000 CNY per stock
initial_capital = 150000 * len(config.BATCH_A_STOCK_CODE)
    # Maximum number of shares per single buy/sell action
max_stock = 50000
initial_stocks_train = np.ones(len(config.BATCH_A_STOCK_CODE), dtype=np.float32)
initial_stocks_vali = np.ones(len(config.BATCH_A_STOCK_CODE), dtype=np.float32)
    # Initial holdings: max_stock shares per ticker
initial_stocks_train = max_stock * initial_stocks_train
initial_stocks_vali = max_stock * initial_stocks_vali
config.IF_ACTUAL_PREDICT = False
config.START_EVAL_DATE = ""
    # TODO: Because of the data volume, the stock data is updated in stock_data.py instead
    # # Update stock data, unadjusted prices
# StockData.update_stock_data_to_sqlite(list_stock_code=config.BATCH_A_STOCK_CODE, adjustflag='3',
# table_name=fe_table_name, if_incremental_update=False)
    # Work well: AgentPPO(), AgentSAC(), AgentTD3(), AgentDDPG(), AgentModSAC()
    # AgentDoubleDQN works well in single-process mode?
    # Do not work well: AgentDuelingDQN(), AgentDoubleDQN(), AgentSharedSAC()
loop_index = 0
    # Main loop
while True:
        # Clear the training-history table
clear_train_history_table_sqlite()
        # From the model_hyper_parameters table, pick the record with the smallest training_times
        # and load its hyper-parameters
hyper_parameters_id, hyper_parameters_model_name, if_on_policy, break_step, train_reward_scale, \
eval_reward_scale, training_times, time_point, state_amount_scale, state_price_scale, state_stocks_scale, \
state_tech_scale = query_model_hyper_parameters_sqlite()
        # Convert the stored string flag to a bool
        if_on_policy = (if_on_policy == 'True')
config.MODEL_HYPER_PARAMETERS = str(hyper_parameters_id)
        # Select the agent class
agent_class = None
        # Model name (agent class)
config.AGENT_NAME = str(hyper_parameters_model_name).split('_')[0]
        # Prediction horizon of the model, in working days
config.AGENT_WORK_DAY = int(str(hyper_parameters_model_name).split('_')[1])
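        # e.g. a model name like 'AgentPPO_20' (hypothetical value) yields
        # AGENT_NAME = 'AgentPPO' and AGENT_WORK_DAY = 20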
if config.AGENT_NAME == 'AgentPPO':
agent_class = AgentPPO()
elif config.AGENT_NAME == 'AgentSAC':
agent_class = AgentSAC()
elif config.AGENT_NAME == 'AgentTD3':
agent_class = AgentTD3()
elif config.AGENT_NAME == 'AgentDDPG':
agent_class = AgentDDPG()
elif config.AGENT_NAME == 'AgentModSAC':
agent_class = AgentModSAC()
elif config.AGENT_NAME == 'AgentDuelingDQN':
agent_class = AgentDuelingDQN()
elif config.AGENT_NAME == 'AgentSharedSAC':
agent_class = AgentSharedSAC()
elif config.AGENT_NAME == 'AgentDoubleDQN':
agent_class = AgentDoubleDQN()
pass
        # Update the working-day flag; run_single.py uses it to load the trained weights file
config.VALI_DAYS_FLAG = str(config.AGENT_WORK_DAY)
        # TODO: Overall end date: today's date, with 60 working days reserved for validating predictions
# config.END_DATE = str(get_next_work_day(get_datetime_from_date_str(get_today_date()),
# -1 * config.AGENT_WORK_DAY))
config.END_DATE = str(get_next_work_day(get_datetime_from_date_str(get_today_date()), -60))
model_folder_path = f'./{config.WEIGHTS_PATH}/batch/{config.AGENT_NAME}/' \
f'batch_{config.VALI_DAYS_FLAG}'
if not os.path.exists(model_folder_path):
os.makedirs(model_folder_path)
pass
        # Date on which evaluation/prediction starts
config.START_EVAL_DATE = str(
get_next_work_day(get_datetime_from_date_str(get_today_date()), -2 * config.AGENT_WORK_DAY))
print('\r\n')
print('-' * 40)
print('config.AGENT_NAME', config.AGENT_NAME)
        print('# train-eval period', config.START_DATE, '-', config.START_EVAL_DATE, '-', config.END_DATE)
print('# work_days', config.AGENT_WORK_DAY)
print('# model_folder_path', model_folder_path)
print('# initial_capital', initial_capital)
print('# max_stock', max_stock)
args = Arguments(if_on_policy=if_on_policy)
args.agent = agent_class
# args.agent.if_use_gae = if_use_gae
args.agent.lambda_entropy = 0.04
args.gpu_id = 0
tech_indicator_list = [
'macd', 'boll_ub', 'boll_lb', 'rsi_30', 'cci_30', 'dx_30',
'close_30_sma', 'close_60_sma'] # finrl.config.TECHNICAL_INDICATORS_LIST
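        # These indicators are presumably precomputed by the feature-engineering
        # step that fills the fe_fillzero_train table.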
gamma = 0.99
buy_cost_pct = 0.003
sell_cost_pct = 0.003
start_date = config.START_DATE
start_eval_date = config.START_EVAL_DATE
end_eval_date = config.END_DATE
# train
args.env = StockTradingEnvBatch(cwd='', gamma=gamma, max_stock=max_stock,
initial_capital=initial_capital,
buy_cost_pct=buy_cost_pct, sell_cost_pct=sell_cost_pct, start_date=start_date,
end_date=start_eval_date, env_eval_date=end_eval_date,
ticker_list=config.BATCH_A_STOCK_CODE,
tech_indicator_list=tech_indicator_list, initial_stocks=initial_stocks_train,
if_eval=False, fe_table_name=fe_table_name)
# eval
args.env_eval = StockTradingEnvBatch(cwd='', gamma=gamma, max_stock=max_stock,
initial_capital=initial_capital,
buy_cost_pct=buy_cost_pct, sell_cost_pct=sell_cost_pct,
start_date=start_date,
end_date=start_eval_date, env_eval_date=end_eval_date,
ticker_list=config.BATCH_A_STOCK_CODE,
tech_indicator_list=tech_indicator_list,
initial_stocks=initial_stocks_vali,
if_eval=True, fe_table_name=fe_table_name)
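        # Both environments share the same date boundaries; if_eval presumably
        # selects the evaluation slice inside StockTradingEnvBatch.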
args.env.target_return = 100
args.env_eval.target_return = 100
        # Reward scaling
args.env.reward_scale = train_reward_scale
args.env_eval.reward_scale = eval_reward_scale
args.env.state_amount_scale = state_amount_scale
args.env.state_price_scale = state_price_scale
args.env.state_stocks_scale = state_stocks_scale
args.env.state_tech_scale = state_tech_scale
args.env_eval.state_amount_scale = state_amount_scale
args.env_eval.state_price_scale = state_price_scale
args.env_eval.state_stocks_scale = state_stocks_scale
args.env_eval.state_tech_scale = state_tech_scale
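        # These factors presumably normalize the cash, price, holdings and
        # technical-indicator components of the observation vector.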
print('train reward_scale', args.env.reward_scale)
print('eval reward_scale', args.env_eval.reward_scale)
print('state_amount_scale', state_amount_scale)
print('state_price_scale', state_price_scale)
print('state_stocks_scale', state_stocks_scale)
print('state_tech_scale', state_tech_scale)
# Hyperparameters
args.gamma = gamma
# args.gamma = 0.99
        # reward_scale is already adjusted on args.env above, so leave it unchanged here
# args.reward_scale = 2 ** 0
args.reward_scale = 1
# args.break_step = int(break_step / 30)
args.break_step = break_step
print('break_step', args.break_step)
args.net_dim = 2 ** 9
args.max_step = args.env.max_step
# args.max_memo = args.max_step * 4
args.max_memo = (args.max_step - 1) * 8
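        # max_memo: replay/trajectory buffer capacity (in transitions), sized here from the episode length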
args.batch_size = 2 ** 12
# args.batch_size = 2305
print('batch_size', args.batch_size)
# ----
# args.repeat_times = 2 ** 3
args.repeat_times = 2 ** 4
# ----
args.eval_gap = 2 ** 4
args.eval_times1 = 2 ** 3
args.eval_times2 = 2 ** 5
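        # eval_gap / eval_times1 / eval_times2 presumably control how often and
        # over how many episodes the evaluator scores the policy.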
args.if_allow_break = False
args.rollout_num = 2 # the number of rollout workers (larger is not always faster)
# train_and_evaluate(args)
train_and_evaluate_mp(args) # the training process will terminate once it reaches the target reward.
        # Save the trained model weights
shutil.copyfile(f'./{config.WEIGHTS_PATH}/StockTradingEnv-v1/actor.pth', f'{model_folder_path}/actor.pth')
        # Keep an extra timestamped copy
shutil.copyfile(f'./{config.WEIGHTS_PATH}/StockTradingEnv-v1/actor.pth',
f'{model_folder_path}/{date_time.time_point(time_format="%Y%m%d_%H%M%S")}.pth')
        # Save the learning-curve plot
# plot_learning_curve.jpg
timepoint_temp = date_time.time_point()
plot_learning_curve_file_path = f'{model_folder_path}/{timepoint_temp}.jpg'
shutil.copyfile(f'./{config.WEIGHTS_PATH}/StockTradingEnv-v1/plot_learning_curve.jpg',
plot_learning_curve_file_path)
        # After training, increment training_times and refresh the training timestamp in the model_hyper_parameters table.
        # If the train_history table has records, take the integer division by 256 plus 128 and subtract the
        # corresponding value from the hyper-parameters stored in the model_hyper_parameters table.
update_model_hyper_parameters_by_train_history(model_hyper_parameters_id=hyper_parameters_id,
origin_train_reward_scale=train_reward_scale,
origin_eval_reward_scale=eval_reward_scale,
origin_training_times=training_times,
origin_state_amount_scale=state_amount_scale,
origin_state_price_scale=state_price_scale,
origin_state_stocks_scale=state_stocks_scale,
origin_state_tech_scale=state_tech_scale)
print('>', config.AGENT_NAME, break_step, 'steps')
        # End time of this run
time_end = datetime.now()
duration = (time_end - time_begin).total_seconds()
        print('elapsed time', duration, 'seconds')
        # Loop counter
loop_index += 1
        print('>', 'while loop count', loop_index, '\r\n')
        print('sleep 10 seconds\r\n')
time.sleep(10)
        # TODO: train once, then exit
break
pass