-
Notifications
You must be signed in to change notification settings - Fork 4
/
profit_calculator.py
118 lines (88 loc) · 3.24 KB
/
profit_calculator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import sys
import argparse
from enum import IntEnum
from typing import List
import pandas as pd
class Action(IntEnum):
    """Trading action codes, as read (one integer per line) from the action file.

    Being an ``IntEnum``, each member compares equal to its plain ``int``
    value, so codes parsed with ``int(...)`` can be compared directly
    (e.g. ``1 == Action.BUY``).
    """

    BUY = 1         # StockTrader.buy: open a long position or close a short one
    NO_ACTION = 0   # leave the current position unchanged
    SOLD = -1       # StockTrader.sell: open a short position or close a long one
class InvalidActionError(Exception):
    """Raised when an action code is not one of the ``Action`` values."""
class StockNumExceedError(Exception):
    """Raised when opening a second position in an already-held direction.

    That is, buying while already long, or selling short while already short.
    """
class InvalidActionNumError(Exception):
    """Raised when the number of actions does not fit the stock data.

    A valid input has exactly one fewer action than there are stock rows.
    """
class StockTrader:
    """Trader holding at most one position (one long or one short share).

    Trading in the direction of an already-open position raises
    ``StockNumExceedError``. Trading against an open position closes it and
    adds the realised difference to ``accumulated_profit``. Trading while
    flat opens a new position at the given price.
    """

    def __init__(self):
        # Profit realised over all closed positions so far.
        self.accumulated_profit = 0
        # Entry price of the open long position; None when no long is open.
        self.holding_price = None
        # Entry price of the open short position; None when no short is open.
        self.sell_short_price = None

    @property
    def is_holding_stock(self):
        """Whether a long position is currently open."""
        return self.holding_price is not None

    @property
    def is_shorting_stock(self):
        """Whether a short position is currently open."""
        return self.sell_short_price is not None

    def perform_action(self, action_code: int, stock_price: float):
        """Execute one action code at ``stock_price``.

        Raises:
            InvalidActionError: ``action_code`` matches no ``Action`` member.
            StockNumExceedError: propagated from ``buy`` / ``sell``.
        """
        if action_code == Action.BUY:
            self.buy(stock_price)
            return
        if action_code == Action.SOLD:
            self.sell(stock_price)
            return
        if action_code != Action.NO_ACTION:
            raise InvalidActionError('Invalid Action')

    def buy(self, stock_price: float):
        """Close an open short at ``stock_price``, or open a long if flat.

        Raises:
            StockNumExceedError: a long position is already open.
        """
        if self.is_holding_stock:
            raise StockNumExceedError('You cannot buy stocks when you hold one')
        if not self.is_shorting_stock:
            self.holding_price = stock_price
            return
        # Covering a short gains when the price has fallen since the entry.
        self.accumulated_profit += self.sell_short_price - stock_price
        self.sell_short_price = None

    def sell(self, stock_price: float):
        """Close an open long at ``stock_price``, or open a short if flat.

        Raises:
            StockNumExceedError: a short position is already open.
        """
        if self.is_shorting_stock:
            raise StockNumExceedError("You cannot sell short stocks when you've already sell short one")
        if not self.is_holding_stock:
            self.sell_short_price = stock_price
            return
        # Closing a long gains when the price has risen since the entry.
        self.accumulated_profit += stock_price - self.holding_price
        self.holding_price = None
def check_stock_actions_length(stocks_df: pd.DataFrame, actions: List[int]) -> bool:
    """Return True when ``stocks_df`` has exactly one more row than ``actions``.

    ``calculate_profit`` skips the first stock row and pairs every remaining
    row with one action, so a valid input has ``len(actions) + 1`` rows.

    Args:
        stocks_df: Daily stock data; only its length is used.
        actions: Action codes, one per trading day after the first row.

    Returns:
        Whether the lengths are consistent.
    """
    return len(stocks_df) == len(actions) + 1
def calculate_profit(stocks_df: pd.DataFrame, actions: List[int]) -> float:
    """Replay ``actions`` against daily prices and return the realised profit.

    The first row of ``stocks_df`` is skipped, and ``actions[i]`` is executed
    at the ``'open'`` price of row ``i + 1``. Any position still open after
    the last executed action is closed at that same row's ``'close'`` price.
    If ``actions`` and the remaining rows differ in length, the extra ones
    are ignored (pairing stops at the shorter of the two).

    The caller's DataFrame is left unmodified.

    Args:
        stocks_df: Daily prices with at least ``'open'`` and ``'close'`` columns.
        actions: Action codes (``Action`` values).

    Returns:
        Accumulated profit; ``0`` when no action was executed.

    Raises:
        InvalidActionError: an action code is not a valid ``Action``.
        StockNumExceedError: an action opens a second position in one direction.
    """
    stock_trader = StockTrader()
    # Positional slice instead of ``drop(0, inplace=True)``: skips the first
    # day without mutating the caller's frame (which made a second call on
    # the same frame fail) and without depending on the index labels.
    trading_days = stocks_df.iloc[1:]

    last_close = None
    for open_price, close_price, action in zip(
            trading_days['open'], trading_days['close'], actions):
        stock_trader.perform_action(action, open_price)
        last_close = close_price

    # Flatten whatever position is left at the close of the last traded day.
    if last_close is not None:
        if stock_trader.is_holding_stock:
            stock_trader.sell(last_close)
        elif stock_trader.is_shorting_stock:
            stock_trader.buy(last_close)
    return stock_trader.accumulated_profit
if __name__ == "__main__":
    # Usage: <script> STOCK_CSV ACTION_FILE -> prints the total profit.
    parser = argparse.ArgumentParser()
    parser.add_argument('stock', help='input stock file name')
    parser.add_argument('action', help='input action file name')
    args = parser.parse_args()

    # Passing ``names`` makes pandas treat the CSV as header-less; columns
    # are assigned these names in order.
    FEATURE_NAMES = ('open', 'high', 'low', 'close')
    stocks_df = pd.read_csv(args.stock, names=FEATURE_NAMES)

    # One integer action code per line. Blank lines (e.g. an extra trailing
    # newline) are skipped rather than crashing on ``int('')``; ``int``
    # itself tolerates surrounding whitespace.
    with open(args.action, 'r', encoding='utf-8') as action_file:
        actions = [int(line) for line in action_file if line.strip()]

    if not check_stock_actions_length(stocks_df, actions):
        raise InvalidActionNumError('Invalid number of actions')

    profit = calculate_profit(stocks_df, actions)
    print(profit)