Skip to content

Commit

Permalink
daily
Browse files Browse the repository at this point in the history
  • Loading branch information
dyh committed Jun 22, 2021
1 parent c4bee15 commit 4b58bcc
Show file tree
Hide file tree
Showing 13 changed files with 3,135 additions and 135 deletions.
217 changes: 150 additions & 67 deletions .idea/workspace.xml

Large diffs are not rendered by default.

28 changes: 20 additions & 8 deletions env_predict_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,11 @@ def __init__(self, cwd='./envs/FinRL', gamma=0.99,
self.episode_return = 0.0

# 输出的缓存
self.output_text_cache = ''
self.output_text_trade_detail = ''

# 输出的list
self.list_buy_or_sell_output = []

pass

def reset(self):
Expand All @@ -72,7 +73,7 @@ def reset(self):

# ----
# 清空输出的缓存
self.output_text_cache = ''
self.output_text_trade_detail = ''

# 输出的list
self.list_buy_or_sell_output.clear()
Expand All @@ -97,7 +98,7 @@ def step(self, actions):
date_ary_temp = self.date_ary[self.day]
date_temp = date_ary_temp[0]

self.output_text_cache += f'第 {self.day + 1} 天,{date_temp}\r\n'
self.output_text_trade_detail += f'第 {self.day + 1} 天,{date_temp}\r\n'

for index in np.where(actions < 0)[0]: # sell_index:
if price[index] > 0: # Sell only if current asset is > 0
Expand All @@ -123,6 +124,7 @@ def step(self, actions):

# tic, date, sell/buy, hold, 第x天
episode_return_temp = (self.amount + (self.stocks * price).sum()) / self.initial_total_asset

list_item = (tic_temp, date_temp, -1 * sell_num_shares, self.stocks[index], self.day + 1,
episode_return_temp)
# 添加到输出list
Expand All @@ -135,14 +137,15 @@ def step(self, actions):

# tic, date, sell/buy, hold, 第x天
episode_return_temp = (self.amount + (self.stocks * price).sum()) / self.initial_total_asset

list_item = (tic_temp, date_temp, 0, self.stocks[index], self.day + 1, episode_return_temp)
# 添加到输出list
self.list_buy_or_sell_output.append(list_item)
pass
pass

price_diff = str(round(price[index] - yesterday_price[index], 6))
self.output_text_cache += f' > {tic_temp},' \
self.output_text_trade_detail += f' > {tic_temp},' \
f'卖出:{sell_num_shares} 股, 持股数量 {self.stocks[index]},' \
f'涨跌:¥{price_diff} 元,' \
f'现金:{self.amount},资产:{self.total_asset} \r\n'
Expand Down Expand Up @@ -186,14 +189,15 @@ def step(self, actions):

# tic, date, sell/buy, hold, 第x天
episode_return_temp = (self.amount + (self.stocks * price).sum()) / self.initial_total_asset

list_item = (tic_temp, date_temp, 0, self.stocks[index], self.day + 1, episode_return_temp)
# 添加到输出list
self.list_buy_or_sell_output.append(list_item)
pass
pass

price_diff = str(round(price[index] - yesterday_price[index], 6))
self.output_text_cache += f' > {tic_temp},' \
self.output_text_trade_detail += f' > {tic_temp},' \
f'买入:{buy_num_shares} 股, 持股数量:{self.stocks[index]},' \
f'涨跌:¥{price_diff} 元,' \
f'现金:{self.amount},资产:{self.total_asset} \r\n'
Expand All @@ -206,6 +210,7 @@ def step(self, actions):
# tic, date, sell/buy, hold, 第x天
tic_temp = tic_ary_temp[index]
episode_return_temp = (self.amount + (self.stocks * price).sum()) / self.initial_total_asset

list_item = (tic_temp, date_temp, 0, self.stocks[index], self.day + 1, episode_return_temp)
# 添加到输出list
self.list_buy_or_sell_output.append(list_item)
Expand All @@ -232,7 +237,7 @@ def step(self, actions):
# ----
if config.IF_SHOW_PREDICT_INFO:
# date_temp = date_ary_temp[index]
print(self.output_text_cache)
print(self.output_text_trade_detail)
print(f'第 {self.day + 1} 天,{date_temp},现金:{self.amount},'
f'股票:{str((self.stocks * price).sum())},总资产:{self.total_asset}')

Expand Down Expand Up @@ -314,15 +319,22 @@ def convert_df_to_ary(df, tech_indicator_list):
price_ary.append(item.close) # adjusted close price (adjcp)

# ----
tic_ary.append(list(item.tic))
date_ary.append(list(item.date))
# tic_ary.append(list(item.tic))
# date_ary.append(list(item.date))

tic_ary.append(item.tic)
date_ary.append(item.date)

# ----

pass

price_ary = np.array(price_ary)
tech_ary = np.array(tech_ary)

tic_ary = np.array(tic_ary)
date_ary = np.array(date_ary)

print(f'| price_ary.shape: {price_ary.shape}, tech_ary.shape: {tech_ary.shape}')
return price_ary, tech_ary, tic_ary, date_ary

Expand Down
53 changes: 47 additions & 6 deletions predict_single_insert_into_psql.py → predict_single_psql.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,41 @@
from run_single import *
from datetime import datetime


def calc_max_return(price_ary, initial_capital_temp):
# ret = 0
max_return_temp = 0
# max_value = 0
min_value = 0

assert price_ary.shape[0] > 1

count_price = price_ary.shape[0]

for index_left in range(0, count_price - 1):

for index_right in range(index_left+1, count_price):

assert price_ary[index_left][0] > 0

assert price_ary[index_right][0] > 0

temp_value = price_ary[index_right][0] - price_ary[index_left][0]

if temp_value > max_return_temp:
max_return_temp = temp_value
# max_value = price_ary[index1][0]
min_value = price_ary[index_right][0]
pass

# print(price_ary[index][0])
pass

ret = (initial_capital_temp / min_value * max_return_temp + initial_capital_temp) / initial_capital_temp

return ret


if __name__ == '__main__':
# 预测,并保存结果到 postgresql 数据库
# 开始预测的时间
Expand All @@ -26,19 +61,18 @@
psql_object = Psqldb(database=config.PSQL_DATABASE, user=config.PSQL_USER,
password=config.PSQL_PASSWORD, host=config.PSQL_HOST, port=config.PSQL_PORT)

config.OUTPUT_DATE = '2021-06-21'
config.OUTPUT_DATE = '2021-06-23'

# 前10后10,前10后x,前x后10
config.PREDICT_PERIOD = '42'
config.PREDICT_PERIOD = '40'

# 好用 AgentPPO(), # AgentSAC(), AgentTD3(), AgentDDPG(), AgentModSAC(),
# AgentDoubleDQN 单进程好用?
# 不好用 AgentDuelingDQN(), AgentDoubleDQN(), AgentSharedSAC()
# for agent_item in ['AgentModSAC', ]:
# , 'AgentModSAC'
# for agent_item in ['AgentTD3', 'AgentPPO']:
# for agent_item in ['AgentSAC', 'AgentTD3', 'AgentDDPG', 'AgentPPO', 'AgentModSAC']:
for agent_item in ['AgentPPO', 'AgentDDPG', 'AgentTD3', 'AgentSAC', 'AgentModSAC']:
# for agent_item in ['AgentPPO', 'AgentDDPG', 'AgentTD3', 'AgentSAC', 'AgentModSAC']:
for agent_item in ['AgentPPO', 'AgentDDPG', 'AgentTD3', 'AgentSAC']:

config.AGENT_NAME = agent_item
# config.CWD = f'./{config.AGENT_NAME}/single/{config.SINGLE_A_STOCK_CODE[0]}/StockTradingEnv-v1'
Expand Down Expand Up @@ -305,13 +339,20 @@
for item in env.list_buy_or_sell_output:
tic, date, action, hold, day, episode_return = item
if str(date) == config.OUTPUT_DATE:
# 简单计算一次,低买高卖的最大回报
max_return = calc_max_return(env.price_ary, env.initial_capital)

# 找到要预测的那一天,存储到psql
StockData.update_predict_result_to_psql(psql=psql_object, agent=config.AGENT_NAME,
vali_period_value=config.VALI_DAYS_FLAG,
pred_period_name=config.PREDICT_PERIOD,
tic=tic, date=date, action=action,
hold=hold,
day=day, episode_return=episode_return)
day=day, episode_return=episode_return,
max_return=max_return)

break

pass
pass
pass
Expand Down
115 changes: 95 additions & 20 deletions predict_single_do_not_insert_into_psql.py → predict_single_sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,43 @@
from run_single import *
from datetime import datetime


def calc_max_return(price_ary, initial_capital_temp):
# ret = 0
max_return_temp = 0
# max_value = 0
min_value = 0

assert price_ary.shape[0] > 1

count_price = price_ary.shape[0]

for index_left in range(0, count_price - 1):

for index_right in range(index_left + 1, count_price):

assert price_ary[index_left][0] > 0

assert price_ary[index_right][0] > 0

temp_value = price_ary[index_right][0] - price_ary[index_left][0]

if temp_value > max_return_temp:
max_return_temp = temp_value
# max_value = price_ary[index1][0]
min_value = price_ary[index_right][0]
pass

# print(price_ary[index][0])
pass

ret = (initial_capital_temp / min_value * max_return_temp + initial_capital_temp) / initial_capital_temp

return ret


if __name__ == '__main__':
# 预测,但不保存到 postgresql 数据库
# 预测,并保存结果到 postgresql 数据库
# 开始预测的时间
time_begin = datetime.now()

Expand All @@ -21,25 +56,26 @@
# 要预测的那一天
config.SINGLE_A_STOCK_CODE = [tic_item, ]

config.OUTPUT_DATE = '2021-06-17'
config.OUTPUT_DATE = '2021-06-23'

# 前10后10,前10后x,前x后10
config.PREDICT_PERIOD = '20'
config.PREDICT_PERIOD = '40'

# 好用 AgentPPO(), # AgentSAC(), AgentTD3(), AgentDDPG(), AgentModSAC(),
# AgentDoubleDQN 单进程好用?
# 不好用 AgentDuelingDQN(), AgentDoubleDQN(), AgentSharedSAC()
# for agent_item in ['AgentModSAC', ]:
# 'AgentDDPG', 'AgentPPO', 'AgentModSAC', 'AgentSAC'
for agent_item in ['AgentTD3', ]:
# , 'AgentModSAC'
# for agent_item in ['AgentPPO', 'AgentDDPG', 'AgentTD3', 'AgentSAC', 'AgentModSAC']:
for agent_item in ['AgentPPO', 'AgentDDPG', 'AgentTD3', 'AgentSAC']:

config.AGENT_NAME = agent_item
# config.CWD = f'./{config.AGENT_NAME}/single/{config.SINGLE_A_STOCK_CODE[0]}/StockTradingEnv-v1'

break_step = int(3e6)

if_on_policy = False
if_use_gae = False
# if_use_gae = False

# 预测的开始日期和结束日期,都固定

Expand All @@ -51,12 +87,12 @@
config.START_DATE = "2002-05-01"

# 向左10工作日
config.START_EVAL_DATE = str(get_next_work_day(get_datetime_from_date_str(config.OUTPUT_DATE), -17))
config.START_EVAL_DATE = str(get_next_work_day(get_datetime_from_date_str(config.OUTPUT_DATE), -39))
# 向右10工作日
config.END_DATE = str(get_next_work_day(get_datetime_from_date_str(config.OUTPUT_DATE), +3))

# 创建预测结果表
# StockData.create_predict_result_table_psql(tic=config.SINGLE_A_STOCK_CODE[0])
StockData.create_predict_result_table_sqlite(tic=config.SINGLE_A_STOCK_CODE[0])

# 更新股票数据
StockData.update_stock_data(tic_code=config.SINGLE_A_STOCK_CODE[0])
Expand Down Expand Up @@ -94,33 +130,45 @@
print('# initial_capital', initial_capital)
print('# max_stock', max_stock)

# Agent
args = Arguments(if_on_policy=if_on_policy)

agent_class = None
if config.AGENT_NAME == 'AgentPPO':
args.agent = AgentPPO()
agent_class = AgentPPO()
if_on_policy = True
pass
elif config.AGENT_NAME == 'AgentSAC':
args.agent = AgentSAC()
agent_class = AgentSAC()
if_on_policy = False
pass
elif config.AGENT_NAME == 'AgentTD3':
args.agent = AgentTD3()
agent_class = AgentTD3()
if_on_policy = False
pass
elif config.AGENT_NAME == 'AgentDDPG':
args.agent = AgentDDPG()
agent_class = AgentDDPG()
if_on_policy = False
pass
elif config.AGENT_NAME == 'AgentModSAC':
args.agent = AgentModSAC()
agent_class = AgentModSAC()
if_on_policy = False
pass
elif config.AGENT_NAME == 'AgentDuelingDQN':
args.agent = AgentDuelingDQN()
agent_class = AgentDuelingDQN()
if_on_policy = False
pass
elif config.AGENT_NAME == 'AgentSharedSAC':
args.agent = AgentSharedSAC()
agent_class = AgentSharedSAC()
if_on_policy = False
pass
elif config.AGENT_NAME == 'AgentDoubleDQN':
agent_class = AgentDoubleDQN()
if_on_policy = False
pass

args = Arguments(if_on_policy=if_on_policy)
args.agent = agent_class

args.gpu_id = 0
args.agent.if_use_gae = if_use_gae
# args.agent.if_use_gae = if_use_gae
args.agent.lambda_entropy = 0.04

tech_indicator_list = [
Expand Down Expand Up @@ -276,7 +324,33 @@
pass
pass

print('>>>> env.list_output', env.list_buy_or_sell_output)
# print('>>>> env.list_output', env.list_buy_or_sell_output)

print(env.output_text_trade_detail)

# 获取要预测的日期,保存到数据库中
for item in env.list_buy_or_sell_output:
tic, date, action, hold, day, episode_return = item
if str(date) == config.OUTPUT_DATE:
# 简单计算一次,低买高卖的最大回报金额
max_return = calc_max_return(env.price_ary, env.initial_capital)

# 找到要预测的那一天,存储到psql
StockData.update_predict_result_to_sqlite(agent=config.AGENT_NAME,
vali_period_value=config.VALI_DAYS_FLAG,
pred_period_name=config.PREDICT_PERIOD,
tic=tic, date=date, action=action,
hold=hold,
day=day, episode_return=episode_return,
max_return=max_return,
trade_detail=env.output_text_trade_detail)

break

pass
pass
pass
# episode_return = getattr(env, 'episode_return', episode_return)
pass
else:
print('未找到模型文件', model_file_path)
Expand All @@ -285,6 +359,7 @@

pass
pass

pass

# 结束预测的时间
Expand Down
Loading

0 comments on commit 4b58bcc

Please sign in to comment.