-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathpredict.py
41 lines (29 loc) · 1.27 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from sp500forecaster.args import parse_predict_args
from sp500forecaster import log, set_console_logger
from sp500forecaster.stock_data_transformer import StockDataTransformer
from sp500forecaster.forecaster import Forecaster
from sp500forecaster.utils import is_iex_supported, get_ohlcv
import os
def predict(args):
    """Run the forecaster over every symbol requested on the command line.

    Parameters
    ----------
    args
        Parsed CLI arguments from ``parse_predict_args`` (presumably an
        ``argparse.Namespace`` -- confirm). Reads ``args.symbols`` (an
        iterable of ticker strings, upper-cased here) and ``args.weights``
        (passed to ``Forecaster.load_weights``).

    Returns
    -------
    dict
        Maps each upper-cased symbol to whatever ``predict_future`` returned
        for it. Callers that ignore the return value are unaffected.
    """
    symbols = [symbol.upper() for symbol in args.symbols]
    log.debug('symbols: %s', symbols)

    # A single transformer instance is handed to the forecaster and reused
    # for every symbol, so the model is loaded only once per run.
    transformer = StockDataTransformer()
    forecaster = Forecaster(transformer)
    forecaster.load_weights(args.weights)

    # Explicit loop rather than a list comprehension: predict_future is
    # called for its side effects (logging) as well as its result, and the
    # results are collected deliberately instead of discarded.
    predictions = {}
    for symbol in symbols:
        predictions[symbol] = predict_future(symbol, transformer, forecaster)
    return predictions
def predict_future(symbol, transformer, forecaster):
    """Predict the next-step class for a single ticker symbol.

    Parameters
    ----------
    symbol : str
        Ticker symbol (``predict`` passes it upper-cased).
    transformer
        ``StockDataTransformer`` used to build the latest input window from
        the symbol's OHLCV data.
    forecaster
        ``Forecaster`` with weights already loaded.

    Returns
    -------
    The element ``[0][0]`` of ``forecaster.predict_classes(...)`` for the
    latest window. A value equal to ``1`` is logged as "positive"; any other
    value is logged as "negative". Returns ``None`` without fetching data
    when the symbol is not supported by IEX.
    """
    if not is_iex_supported(symbol):
        # Skip unsupported symbols rather than requesting OHLCV data that
        # the data source cannot provide.
        log.debug('symbol %s is not supported by iex', symbol)
        return None

    log.debug('processing %s', symbol)
    latest_window = transformer.build_latest_win(symbol, get_ohlcv(symbol))
    # batch_size=1: a single window is predicted; [0][0] takes the class of
    # that one sample.
    prediction = forecaster.predict_classes(latest_window, batch_size=1)[0][0]
    status = 'positive' if prediction == 1 else 'negative'
    log.debug('%s future prediction for symbol %s', status, symbol)
    return prediction
if __name__ == '__main__':
    # Script entry point: parse CLI args, optionally enable console logging,
    # make sure the output directory exists, then run predictions.
    args = parse_predict_args()
    if args.verbose:
        set_console_logger()
    # exist_ok=True replaces the racy exists-then-create check and is a
    # no-op when the directory is already present. It raises
    # FileExistsError if the path exists but is not a directory.
    os.makedirs(args.output, exist_ok=True)
    predict(args)