diff --git a/examples/em_alpaca.py b/examples/em_alpaca.py index 9ee6f518..81117cf7 100644 --- a/examples/em_alpaca.py +++ b/examples/em_alpaca.py @@ -13,50 +13,48 @@ import pandas as pd import mplfinance as mpf + class EMAlgo(BaseAlgo): def setup(self): now = dt.datetime.now() - logging.info(f'EMAlgo.setup ran at: {now}') + logging.info(f"EMAlgo.setup ran at: {now}") def init_ticker(ticker): - return { - ticker: { - 'initial_price': None, - 'ohlc': pd.DataFrame() - } - } + return {ticker: {"initial_price": None, "ohlc": pd.DataFrame()}} self.tickers = {} - self.tickers.update(init_ticker('AAPL')) - self.tickers.update(init_ticker('MSFT')) + self.tickers.update(init_ticker("AAPL")) + self.tickers.update(init_ticker("MSFT")) def main(self): now = dt.datetime.now() - logging.info(f'EMAlgo.main ran at: {now}') + logging.info(f"EMAlgo.main ran at: {now}") - if now - now.replace(hour=0, minute=0, second=0, microsecond=0) <= dt.timedelta(seconds=60): - logger.info(f'It\'s a new day! Clearning OHLC caches!') + if now - now.replace(hour=0, minute=0, second=0, microsecond=0) <= dt.timedelta( + seconds=60 + ): + logger.info(f"It's a new day! Clearning OHLC caches!") for ticker_value in self.tickers.values(): - ticker_value['ohlc'] = pd.DataFrame() + ticker_value["ohlc"] = pd.DataFrame() for ticker, ticker_value in self.tickers.items(): current_price = self.get_asset_price(ticker) current_ohlc = self.get_asset_candle_list(ticker) - if ticker_value['initial_price'] is None: - ticker_value['initial_price'] = current_price + if ticker_value["initial_price"] is None: + ticker_value["initial_price"] = current_price self.process_ticker(ticker, ticker_value, current_price, current_ohlc) def process_ticker(self, ticker, ticker_data, current_price, current_ohlc): - initial_price = ticker_data['initial_price'] - ohlc = ticker_data['ohlc'] + initial_price = ticker_data["initial_price"] + ohlc = ticker_data["ohlc"] # Calculate the price change delta_price = current_price - initial_price # Print stock info - logging.info(f'{ticker} current price: ${current_price}') - logging.info(f'{ticker} price change: ${delta_price}') + logging.info(f"{ticker} current price: ${current_price}") + logging.info(f"{ticker} price change: ${delta_price}") # Update the OHLC data print("ohlc", ohlc) @@ -64,21 +62,21 @@ def process_ticker(self, ticker, ticker_data, current_price, current_ohlc): mpf.plot(ohlc) -if __name__ == '__main__': +if __name__ == "__main__": # Store the OHLC data in a folder called `em_storage` with each file stored as a csv document - csv_storage = CSVStorage(save_dir='em_storage') + csv_storage = CSVStorage(save_dir="em_storage") # Our streamer and broker will be Alpaca. My secret keys are stored in `alpaca_secret.yaml` - alpaca = Alpaca(path='accounts/alpaca-secret.yaml', is_basic_account=True, paper_trader=True) + alpaca = Alpaca( + path="accounts/alpaca-secret.yaml", is_basic_account=True, paper_trader=True + ) em_algo = EMAlgo() trader = LiveTrader(streamer=alpaca, broker=alpaca, storage=csv_storage, debug=True) # Watch for Apple and Microsoft - trader.set_symbol('AAPL') - trader.set_symbol('MSFT') + trader.set_symbol("AAPL") + trader.set_symbol("MSFT") trader.set_algo(em_algo) # Update every minute - trader.start('1MIN', all_history=False) - - + trader.start("1MIN", all_history=False) diff --git a/examples/em_kraken.py b/examples/em_kraken.py index 909dd88b..3ac92857 100644 --- a/examples/em_kraken.py +++ b/examples/em_kraken.py @@ -24,13 +24,14 @@ def init_ticker(ticker): fig = mpf.figure() ax1 = fig.add_subplot(2, 1, 1) ax2 = fig.add_subplot(3, 1, 3) - + return { ticker: { - "initial_price": None, "ohlc": pd.DataFrame(), - "fig": fig, - "ax1": ax1, - "ax2": ax2 + "initial_price": None, + "ohlc": pd.DataFrame(), + "fig": fig, + "ax1": ax1, + "ax2": ax2, } } @@ -56,7 +57,9 @@ def main(self): ticker_value["initial_price"] = current_price if current_ohlc.empty: - logging.warn(f"{ticker}'s get_asset_candle_list returned an empty list.") + logging.warn( + f"{ticker}'s get_asset_candle_list returned an empty list." + ) return ticker_value["ohlc"] = ticker_value["ohlc"].append(current_ohlc) @@ -75,9 +78,9 @@ def process_ticker(self, ticker, ticker_data, current_price): logging.info(f"{ticker} price change: ${delta_price}") # Update the OHLC graph - ticker_data['ax1'].clear() - ticker_data['ax2'].clear() - mpf.plot(ohlc, ax=ticker_data['ax1'], volume=ticker_data['ax2'], type="candle") + ticker_data["ax1"].clear() + ticker_data["ax2"].clear() + mpf.plot(ohlc, ax=ticker_data["ax1"], volume=ticker_data["ax2"], type="candle") plt.pause(3) @@ -85,9 +88,7 @@ def process_ticker(self, ticker, ticker_data, current_price): # Store the OHLC data in a folder called `em_storage` with each file stored as a csv document csv_storage = CSVStorage(save_dir="em_storage") # Our streamer and broker will be Alpaca. My secret keys are stored in `alpaca_secret.yaml` - kraken = Kraken( - path="accounts/kraken-secret.yaml" - ) + kraken = Kraken(path="accounts/kraken-secret.yaml") em_algo = EMAlgo() trader = LiveTrader(streamer=kraken, broker=kraken, storage=csv_storage, debug=True) diff --git a/examples/em_polygon.py b/examples/em_polygon.py index 95ebcda6..18c5cd6b 100644 --- a/examples/em_polygon.py +++ b/examples/em_polygon.py @@ -15,6 +15,7 @@ import matplotlib.pyplot as plt import mplfinance as mpf + class EMAlgo(BaseAlgo): def config(self): self.watchlist = ["@BTC"] @@ -22,15 +23,10 @@ def config(self): def setup(self): now = dt.datetime.now() - logging.info(f'EMAlgo.setup ran at: {now}') + logging.info(f"EMAlgo.setup ran at: {now}") def init_ticker(ticker): - return { - ticker: { - 'initial_price': None, - 'ohlc': pd.DataFrame() - } - } + return {ticker: {"initial_price": None, "ohlc": pd.DataFrame()}} self.tickers = {} for ticker in self.watchlist: @@ -38,13 +34,18 @@ def init_ticker(ticker): def main(self): now = dt.datetime.now() - logging.info('*' * 20) - logging.info(f'EMAlgo.main ran at: {now}') + logging.info("*" * 20) + logging.info(f"EMAlgo.main ran at: {now}") - if now - now.replace(hour=0, minute=0, second=0, microsecond=0) <= dt.timedelta(seconds=60): - logger.info(f'It\'s a new day! Clearning OHLC caches!') + if now - now.replace(hour=0, minute=0, second=0, microsecond=0) <= dt.timedelta( + seconds=60 + ): + logger.info(f"It's a new day! Clearning OHLC caches!") for ticker_value in self.tickers.values(): - ticker_value['ohlc'] = pd.DataFrame(columns=['open', 'high', 'low', 'close', 'volume'], index=['timestamp']) + ticker_value["ohlc"] = pd.DataFrame( + columns=["open", "high", "low", "close", "volume"], + index=["timestamp"], + ) for ticker, ticker_value in self.tickers.items(): current_price = self.get_asset_price(ticker) @@ -52,20 +53,22 @@ def main(self): if current_ohlc is None: logging.warn("No ohlc returned!") return - ticker_value['ohlc'] = ticker_value['ohlc'].append(current_ohlc) - ticker_value['ohlc'] = ticker_value['ohlc'][~ticker_value['ohlc'].index.duplicated(keep='first')] + ticker_value["ohlc"] = ticker_value["ohlc"].append(current_ohlc) + ticker_value["ohlc"] = ticker_value["ohlc"][ + ~ticker_value["ohlc"].index.duplicated(keep="first") + ] - if ticker_value['initial_price'] is None: - ticker_value['initial_price'] = current_price + if ticker_value["initial_price"] is None: + ticker_value["initial_price"] = current_price - logging.info('-' * 5 + ticker + '-' * 5) + logging.info("-" * 5 + ticker + "-" * 5) self.process_ticker(ticker, ticker_value, current_price) - logging.info('-' * 20) - logging.info('*' * 20) + logging.info("-" * 20) + logging.info("*" * 20) def process_ticker(self, ticker, ticker_data, current_price): - initial_price = ticker_data['initial_price'] - ohlc = ticker_data['ohlc'] + initial_price = ticker_data["initial_price"] + ohlc = ticker_data["ohlc"] if ohlc.empty: logging.warning(f"{ticker} does not have ohlc info! Not processing.") @@ -75,18 +78,21 @@ def process_ticker(self, ticker, ticker_data, current_price): delta_price = current_price - initial_price # Print stock info - logging.info(f'{ticker} current price: ${current_price}') - logging.info(f'{ticker} price change: ${delta_price}') + logging.info(f"{ticker} current price: ${current_price}") + logging.info(f"{ticker} price change: ${delta_price}") axes.clear() mpf.plot(ohlc, ax=axes, block=False) plt.pause(3) -if __name__ == '__main__': + +if __name__ == "__main__": # Store the OHLC data in a folder called `em_storage` with each file stored as a csv document - csv_storage = CSVStorage(save_dir='em-polygon-storage') + csv_storage = CSVStorage(save_dir="em-polygon-storage") # Our streamer will be Polygon and the broker will be Harvest's paper trader. My secret keys are stored in `polygon-secret.yaml` - polygon = PolygonStreamer(path='accounts/polygon-secret.yaml', is_basic_account=True) + polygon = PolygonStreamer( + path="accounts/polygon-secret.yaml", is_basic_account=True + ) paper = PaperBroker() em_algo = EMAlgo() trader = LiveTrader(streamer=polygon, broker=paper, storage=csv_storage, debug=True) @@ -96,6 +102,4 @@ def process_ticker(self, ticker, ticker_data, current_price): fig = mpf.figure() axes = fig.add_subplot(1, 1, 1) # Update every minute - trader.start('1MIN', all_history=False) - - + trader.start("1MIN", all_history=False) diff --git a/harvest/algo.py b/harvest/algo.py index 18610722..a7caabd8 100644 --- a/harvest/algo.py +++ b/harvest/algo.py @@ -41,6 +41,9 @@ def __init__(self): self.aggregations = None self.watchlist = [] + def init(self, stats): + self.stats = stats + def config(self): """ This method is called before any other methods (except for __init__), @@ -248,8 +251,8 @@ def filter_option_chain( """ if symbol is None: symbol = self.watchlist[0] - lower_exp = convert_input_to_datetime(lower_exp, self.trader.timezone) - upper_exp = convert_input_to_datetime(upper_exp, self.trader.timezone) + lower_exp = convert_input_to_datetime(lower_exp, self.stats.timezone) + upper_exp = convert_input_to_datetime(upper_exp, self.stats.timezone) exp_dates = self.get_option_chain_info(symbol)["exp_dates"] if lower_exp is not None: @@ -303,7 +306,7 @@ def get_option_chain(self, symbol: str, date): """ if symbol is None: symbol = self.watchlist[0] - date = convert_input_to_datetime(date, self.trader.timezone) + date = convert_input_to_datetime(date, self.stats.timezone) print(f"Date: {date}\n") return self.trader.fetch_chain_data(symbol, date) @@ -336,7 +339,7 @@ def _default_param(self, symbol, interval, ref, prices): raise Exception(f"No prices found for symbol {symbol}") else: if interval is None: - interval = self.trader.interval[symbol]["interval"] + interval = self.trader.stats.interval[symbol]["interval"] else: interval = interval_string_to_enum(interval) if prices == None: @@ -610,7 +613,7 @@ def get_asset_candle(self, symbol: str, interval=None) -> pd.DataFrame(): if len(symbol) <= 6: df = self.trader.storage.load(symbol, interval).iloc[[-1]][symbol] print(self.trader.storage.load(symbol, interval)) - return pandas_timestamp_to_local(df, self.trader.timezone) + return pandas_timestamp_to_local(df, self.stats.timezone) debugger.warning("Candles not available for options") return None @@ -639,7 +642,7 @@ def get_asset_candle_list( if interval is None: interval = self.interval df = self.trader.storage.load(symbol, interval)[symbol] - return pandas_timestamp_to_local(df, self.trader.timezone) + return pandas_timestamp_to_local(df, self.stats.timezone) def get_asset_returns(self, symbol=None) -> float: """Returns the return of a specified asset. @@ -770,9 +773,7 @@ def get_datetime(self): :returns: The current date and time as a datetime object """ - return datetime_utc_to_local( - self.trader.streamer.timestamp, self.trader.timezone - ) + return datetime_utc_to_local(self.stats.timestamp, self.stats.timezone) def get_option_position_quantity(self, symbol: str = None) -> bool: """Returns the number of types of options held for a stock. diff --git a/harvest/api/_base.py b/harvest/api/_base.py index 0351b309..97ad66b5 100644 --- a/harvest/api/_base.py +++ b/harvest/api/_base.py @@ -81,7 +81,7 @@ def refresh_cred(self): """ debugger.info(f"Refreshing credentials for {type(self).__name__}.") - def setup(self, interval: Dict, trader_main=None) -> None: + def setup(self, stats: Stats, trader_main=None) -> None: """ This function is called right before the algorithm begins, and initializes several runtime parameters like @@ -91,10 +91,12 @@ def setup(self, interval: Dict, trader_main=None) -> None: """ self.trader_main = trader_main + self.stats = stats + self.stats.timestamp = now() min_interval = None - for sym in interval: - inter = interval[sym]["interval"] + for sym in stats.interval: + inter = stats.interval[sym]["interval"] # If the specified interval is not supported on this API, raise Exception if inter < self.interval_list[0]: raise Exception(f"Specified interval {inter} is not supported.") @@ -103,13 +105,13 @@ def setup(self, interval: Dict, trader_main=None) -> None: if inter not in self.interval_list: granular_int = [i for i in self.interval_list if i < inter] new_inter = granular_int[-1] - interval[sym]["aggregations"].append(inter) - interval[sym]["interval"] = new_inter + stats.interval[sym]["aggregations"].append(inter) + stats.interval[sym]["interval"] = new_inter - if min_interval is None or interval[sym]["interval"] < min_interval: - min_interval = interval[sym]["interval"] + if min_interval is None or stats.interval[sym]["interval"] < min_interval: + min_interval = stats.interval[sym]["interval"] - self.interval = interval + self.interval = stats.interval self.poll_interval = min_interval debugger.debug(f"Interval: {self.interval}") debugger.debug(f"Poll Interval: {self.poll_interval}") @@ -134,7 +136,7 @@ def start(self): cur = now() minutes = cur.minute if minutes % val == 0 and minutes != cur_min: - self.timestamp = cur + self.stats.timestamp = cur self.main() time.sleep(sleep) cur_min = minutes @@ -144,7 +146,7 @@ def start(self): cur = now() minutes = cur.minute if minutes == 0 and minutes != cur_min: - self.timestamp = cur + self.stats.timestamp = cur self.main() time.sleep(sleep) cur_min = minutes @@ -154,7 +156,7 @@ def start(self): minutes = cur.minute hours = cur.hour if hours == 19 and minutes == 50: - self.timestamp = cur + self.stats.timestamp = cur self.main() time.sleep(80000) cur_min = minutes @@ -179,8 +181,8 @@ def main(self): df_dict = {} for sym in self.interval: inter = self.interval[sym]["interval"] - if is_freq(harvest_timestamp, inter): - n = harvest_timestamp + if is_freq(self.stats.timestamp, inter): + n = self.stats.timestamp latest = self.fetch_price_history( sym, inter, n - interval_to_timedelta(inter) * 2, n ) @@ -852,7 +854,7 @@ def main(self, df_dict): for sym in self.interval if is_freq(now(), self.interval[sym]["interval"]) ] - self.timestamp = df_dict[got[0]].index[0] + self.stats.timestamp = df_dict[got[0]].index[0] debugger.debug(f"Needs: {self.needed}") debugger.debug(f"Got data for: {got}") diff --git a/harvest/api/yahoo.py b/harvest/api/yahoo.py index 0d974902..b10c940a 100644 --- a/harvest/api/yahoo.py +++ b/harvest/api/yahoo.py @@ -31,7 +31,7 @@ def setup(self, interval: Dict, trader_main=None): self.watch_ticker = {} - for s in interval: + for s in self.stats.interval: if is_crypto(s): self.watch_ticker[s] = yf.Ticker(s[1:] + "-USD") else: @@ -216,7 +216,8 @@ def fetch_option_market_data(self, occ_symbol: str): chain = self.watch_ticker[symbol].option_chain(date_to_str(date)) chain = chain.calls if typ == "call" else chain.puts df = chain[chain["contractSymbol"] == occ_symbol] - debugger.debug(occ_symbol, df) + + debugger.debug(df) return { "price": float(df["lastPrice"].iloc[0]), "ask": float(df["ask"].iloc[0]), diff --git a/harvest/trader/tester.py b/harvest/trader/tester.py index a9513914..e89c6462 100644 --- a/harvest/trader/tester.py +++ b/harvest/trader/tester.py @@ -66,10 +66,11 @@ def start( a.config() self._setup(source, interval, aggregations, path, start, end, period) - self.broker.setup(self.interval, self.main) - self.streamer.setup(self.interval, self.main) + self.broker.setup(self.stats, self.main) + self.streamer.setup(self.stats, self.main) for a in self.algo: + a.init(self.stats) a.setup() a.trader = self @@ -92,8 +93,8 @@ def _setup( self.storage.limit_size = False - start = convert_input_to_datetime(start, self.timezone) - end = convert_input_to_datetime(end, self.timezone) + start = convert_input_to_datetime(start, self.stats.timezone) + end = convert_input_to_datetime(end, self.stats.timezone) period = convert_input_to_timedelta(period) if start is None: @@ -116,10 +117,12 @@ def _setup( common_start = None common_end = None - for s in self.interval: - for i in [self.interval[s]["interval"]] + self.interval[s]["aggregations"]: + for s in self.stats.interval: + for i in [self.stats.interval[s]["interval"]] + self.stats.interval[s][ + "aggregations" + ]: df = self.storage.load(s, i) - df = pandas_datetime_to_utc(df, self.timezone) + df = pandas_datetime_to_utc(df, self.stats.timezone) if common_start is None or df.index[0] > common_start: common_start = df.index[0] if common_end is None or df.index[-1] < common_end: @@ -144,8 +147,10 @@ def _setup( print(f"Common start: {start}, common end: {end}") - for s in self.interval: - for i in [self.interval[s]["interval"]] + self.interval[s]["aggregations"]: + for s in self.stats.interval: + for i in [self.stats.interval[s]["interval"]] + self.stats.interval[s][ + "aggregations" + ]: df = self.storage.load(s, i).copy() df = df.loc[start:end] self.storage.reset(s, i) @@ -161,14 +166,14 @@ def _setup( } # Generate the "simulated aggregation" data - for sym in self.interval: - interval = self.interval[sym]["interval"] + for sym in self.stats.interval: + interval = self.stats.interval[sym]["interval"] interval_txt = interval_enum_to_string(interval) df = self.storage.load(sym, interval) df_len = len(df.index) debugger.debug(f"Formatting {sym} data...") - for agg in self.interval[sym]["aggregations"]: + for agg in self.stats.interval[sym]["aggregations"]: agg_txt = interval_enum_to_string(agg) # tmp_path = f"{path}/{sym}-{interval_txt}+{agg_txt}.pickle" tmp_path = f"{path}/{sym}@{int(agg)-16}.pickle" @@ -196,10 +201,10 @@ def _setup( save_pickle=False, ) debugger.debug("Formatting complete") - for sym in self.interval: - for agg in self.interval[sym]["aggregations"]: + for sym in self.stats.interval: + for agg in self.stats.interval[sym]["aggregations"]: data = self.storage.load(sym, int(agg) - 16) - data = pandas_datetime_to_utc(data, self.timezone) + data = pandas_datetime_to_utc(data, self.stats.timezone) self.storage.store( sym, int(agg) - 16, @@ -209,20 +214,20 @@ def _setup( # # Save the current state of the queue # for s in self.watch: - # self.load.append_entry(s, self.interval, self.storage.load(s, self.interval)) + # self.load.append_entry(s, self.stats.interval, self.storage.load(s, self.stats.interval)) # for i in self.aggregations: # self.load.append_entry(s, '-'+i, self.storage.load(s, '-'+i), False, True) # self.load.append_entry(s, i, self.storage.load(s, i)) # Move all data to a cached dataframe - for sym in self.interval: + for sym in self.stats.interval: self.df[sym] = {} - inter = self.interval[sym]["interval"] + inter = self.stats.interval[sym]["interval"] interval_txt = interval_enum_to_string(inter) df = self.storage.load(sym, inter) self.df[sym][inter] = df.copy() - for agg in self.interval[sym]["aggregations"]: + for agg in self.stats.interval[sym]["aggregations"]: # agg_txt = interval_enum_to_string(agg) # agg_txt = f"{interval_txt}+{agg_txt}" df = self.storage.load(sym, int(agg) - 16) @@ -231,7 +236,7 @@ def _setup( # Trim data so start and end dates match between assets and intervals # data_start = pytz.utc.localize(dt.datetime(1970, 1, 1)) # data_end = pytz.utc.localize(dt.datetime.utcnow().replace(microsecond=0, second=0)) - # for i in [self.interval] + self.aggregations: + # for i in [self.stats.interval] + self.aggregations: # for s in self.watch: # start = self.df[i][s].index[0] # end = self.df[i][s].index[-1] @@ -240,7 +245,7 @@ def _setup( # if end < data_end: # data_end = end - # for i in [self.interval] + self.aggregations: + # for i in [self.stats.interval] + self.aggregations: # for s in self.watch: # self.df[i][s] = self.df[i][s].loc[data_start:data_end] @@ -253,8 +258,10 @@ def read_pickle_data(self): :path: Path to the local data file :date_format: The format of the data's timestamps """ - for s in self.interval: - for i in [self.interval[s]["interval"]] + self.interval[s]["aggregations"]: + for s in self.stats.interval: + for i in [self.stats.interval[s]["interval"]] + self.stats.interval[s][ + "aggregations" + ]: df = self.storage.open(s, i).dropna() if df.empty or now() - df.index[-1] > dt.timedelta(days=1): df = self.streamer.fetch_price_history(s, i).dropna() @@ -267,8 +274,10 @@ def read_csv_data(self, path: str, date_format: str = "%Y-%m-%d %H:%M:%S"): :path: Path to the local data file :date_format: The format of the data's timestamps """ - for s in self.interval: - for i in [self.interval[s]["interval"]] + self.interval[s]["aggregations"]: + for s in self.stats.interval: + for i in [self.stats.interval[s]["interval"]] + self.stats.interval[s][ + "aggregations" + ]: i_txt = interval_enum_to_string(i) df = self.read_csv(f"{path}/{s}-{i_txt}.csv").dropna() if df.empty: @@ -299,8 +308,10 @@ def run_backtest(self): # pr.enable() # Reset them - for s in self.interval: - for i in [self.interval[s]["interval"]] + self.interval[s]["aggregations"]: + for s in self.stats.interval: + for i in [self.stats.interval[s]["interval"]] + self.stats.interval[s][ + "aggregations" + ]: self.storage.reset(s, i) self.storage.limit_size = True @@ -309,10 +320,10 @@ def run_backtest(self): common_end = self.common_end counter = {} - for s in self.interval: - inter = self.interval[s]["interval"] + for s in self.stats.interval: + inter = self.stats.interval[s]["interval"] start_index = list(self.df[s][inter].index).index(common_start) - self.interval[s]["start"] = start_index + self.stats.interval[s]["start"] = start_index counter[s] = 0 self.timestamp = common_start.to_pydatetime() @@ -320,8 +331,8 @@ def run_backtest(self): while self.timestamp <= common_end: df_dict = {} - for sym in self.interval: - inter = self.interval[sym]["interval"] + for sym in self.stats.interval: + inter = self.stats.interval[sym]["interval"] if is_freq(self.timestamp, inter): # If data is not in the cache, skip it if self.timestamp in self.df[sym][inter].index: @@ -329,8 +340,8 @@ def run_backtest(self): update = self._update_order_queue() self._update_position_cache(df_dict, new=update, option_update=True) - for sym in self.interval: - inter = self.interval[sym]["interval"] + for sym in self.stats.interval: + inter = self.stats.interval[sym]["interval"] if is_freq(self.timestamp, inter): # If data is not in the cache, skip it @@ -339,9 +350,9 @@ def run_backtest(self): df = self.df[sym][inter].loc[[self.timestamp], :] self.storage.store(s, inter, df, save_pickle=False) # Add data to aggregation queue - for agg in self.interval[sym]["aggregations"]: + for agg in self.stats.interval[sym]["aggregations"]: df = self.df[s][int(agg) - 16].iloc[ - [self.interval[sym]["start"] + counter[sym]], : + [self.stats.interval[sym]["start"] + counter[sym]], : ] self.storage.store(s, agg, df) counter[sym] += 1 diff --git a/harvest/trader/trader.py b/harvest/trader/trader.py index 9b1c1e05..8a9df719 100644 --- a/harvest/trader/trader.py +++ b/harvest/trader/trader.py @@ -84,8 +84,10 @@ def _init_attributes(self): self.server = Server(self) # Initialize the web interface server - self.timezone = tzlocal.get_localzone() - debugger.debug(f"Timezone: {self.timezone}") + self.stats = Stats() # Initialize the stats object + self.stats.timestamp = None + self.stats.timezone = tzlocal.get_localzone() + self.stats.interval = None def _setup_debugger(self, debug): # Set up logger @@ -135,7 +137,7 @@ def start( # Initialize a dict of symbols and the intervals they need to run at self._setup_params(interval, aggregations) - if len(self.interval) == 0: + if len(self.stats.interval) == 0: raise Exception("No securities were added to watchlist") # Initialize the account @@ -144,16 +146,17 @@ def start( self.account["equity"], self.streamer.timestamp ) - self.broker.setup(self.interval, self.main) + self.broker.setup(self.stats, self.main) if self.broker != self.streamer: # Only call the streamer setup if it is a different # instance than the broker otherwise some brokers can fail! - self.streamer.setup(self.interval, self.main) + self.streamer.setup(self.stats, self.main) # Initialize the storage self._storage_init(all_history) for a in self.algo: + a.init(self.stats) a.trader = self a.setup() @@ -194,15 +197,10 @@ def _setup_params(self, interval, aggregations): """ interval = interval_string_to_enum(interval) aggregations = [interval_string_to_enum(a) for a in aggregations] - self.interval = {} - - # Initialize a dict with symbol keys and values indicating - # what data intervals they need. - for sym in self.watchlist_global: - self.interval[sym] = {} - self.interval[sym]["interval"] = interval - self.interval[sym]["aggregations"] = aggregations - + watch_dict = { + sym: {"interval": interval, "aggregations": aggregations} + for sym in self.watchlist_global + } # Update the dict based on parameters specified in Algo class for a in self.algo: a.config() @@ -224,27 +222,29 @@ def _setup_params(self, interval, aggregations): for sym in a.watchlist: # If the algorithm needs data for the symbol at a higher frequency than # it is currently available in the Trader class, update the interval - if sym in self.interval: - cur_interval = self.interval[sym]["interval"] + if sym in watch_dict: + cur_interval = watch_dict[sym]["interval"] if a.interval < cur_interval: - self.interval[sym]["aggregations"].append(cur_interval) - self.interval[sym]["interval"] = a.interval + watch_dict[sym]["aggregations"].append(cur_interval) + watch_dict[sym]["interval"] = a.interval # If symbol is not in global watchlist, simply add it else: - self.interval[sym] = {} - self.interval[sym]["interval"] = a.interval - self.interval[sym]["aggregations"] = a.aggregations + watch_dict[sym] = {} + watch_dict[sym]["interval"] = a.interval + watch_dict[sym]["aggregations"] = a.aggregations # If the algo specifies an aggregation that is currently not set, add it to the # global aggregation list for agg in a.aggregations: - if agg not in self.interval[sym]["aggregations"]: - self.interval[sym]["aggregations"].append(agg) + if agg not in watch_dict[sym]["aggregations"]: + watch_dict[sym]["aggregations"].append(agg) # Remove any duplicates in the dict - for sym in self.interval: - new_agg = list((set(self.interval[sym]["aggregations"]))) - self.interval[sym]["aggregations"] = [] if new_agg is None else new_agg + for sym in watch_dict: + new_agg = list((set(watch_dict[sym]["aggregations"]))) + watch_dict[sym]["aggregations"] = [] if new_agg is None else new_agg + + self.stats.interval = watch_dict def _setup_account(self): """Initializes local cache of account info. @@ -261,10 +261,10 @@ def _storage_init(self, all_history: bool): :all_history: bool : """ - for sym in self.interval: - for inter in [self.interval[sym]["interval"]] + self.interval[sym][ - "aggregations" - ]: + for sym in self.stats.interval: + for inter in [self.stats.interval[sym]["interval"]] + self.stats.interval[ + sym + ]["aggregations"]: start = None if all_history else now() - dt.timedelta(days=3) df = self.streamer.fetch_price_history(sym, inter, start) self.storage.store(sym, inter, df) @@ -276,24 +276,19 @@ def main(self, df_dict): Main loop of the Trader. """ # Periodically refresh access tokens - if ( - self.streamer.timestamp.hour % 12 == 0 - and self.streamer.timestamp.minute == 0 - ): + if self.stats.timestamp.hour % 12 == 0 and self.stats.timestamp.minute == 0: self.streamer.refresh_cred() - self.storage.add_performance_data( - self.account["equity"], self.streamer.timestamp - ) + self.storage.add_performance_data(self.account["equity"], self.stats.timestamp) # Save the data locally for sym in df_dict: - self.storage.store(sym, self.interval[sym]["interval"], df_dict[sym]) + self.storage.store(sym, self.stats.interval[sym]["interval"], df_dict[sym]) # Aggregate the data to other intervals for sym in df_dict: - for agg in self.interval[sym]["aggregations"]: - self.storage.aggregate(sym, self.interval[sym]["interval"], agg) + for agg in self.stats.interval[sym]["aggregations"]: + self.storage.aggregate(sym, self.stats.interval[sym]["interval"], agg) # If an order was processed, fetch the latest position info. # Otherwise, calculate current positions locally @@ -304,7 +299,7 @@ def main(self, df_dict): new_algo = [] for a in self.algo: - if not is_freq(self.streamer.timestamp, a.interval): + if not is_freq(self.stats.timestamp, a.interval): new_algo.append(a) continue try: @@ -411,11 +406,13 @@ def _update_position_cache(self, df_dict, new=False, option_update=False): def _fetch_account_data(self): pos = self.broker.fetch_stock_positions() - self.stock_positions = [p for p in pos if p["symbol"] in self.interval] + self.stock_positions = [p for p in pos if p["symbol"] in self.stats.interval] pos = self.broker.fetch_option_positions() - self.option_positions = [p for p in pos if p["base_symbol"] in self.interval] + self.option_positions = [ + p for p in pos if p["base_symbol"] in self.stats.interval + ] pos = self.broker.fetch_crypto_positions() - self.crypto_positions = [p for p in pos if p["symbol"] in self.interval] + self.crypto_positions = [p for p in pos if p["symbol"] in self.stats.interval] ret = self.broker.fetch_account() self.account = ret @@ -436,7 +433,7 @@ def buy(self, symbol: str, quantity: int, in_force: str, extended: bool): if symbol_type(symbol) == "OPTION": price = self.streamer.fetch_option_market_data(symbol)["price"] else: - price = self.storage.load(symbol, self.interval[symbol]["interval"])[ + price = self.storage.load(symbol, self.stats.interval[symbol]["interval"])[ symbol ]["close"][-1] @@ -462,7 +459,7 @@ def buy(self, symbol: str, quantity: int, in_force: str, extended: bool): debugger.debug("BUY failed") return None self.order_queue.append(ret) - debugger.debug(f"BUY: {self.streamer.timestamp}, {symbol}, {quantity}") + debugger.debug(f"BUY: {self.stats.timestamp}, {symbol}, {quantity}") return ret @@ -478,7 +475,7 @@ def sell(self, symbol: str, quantity: int, in_force: str, extended: bool): if symbol_type(symbol) == "OPTION": price = self.streamer.fetch_option_market_data(symbol)["price"] else: - price = self.storage.load(symbol, self.interval[symbol]["interval"])[ + price = self.storage.load(symbol, self.stats.interval[symbol]["interval"])[ symbol ]["close"][-1] @@ -489,7 +486,7 @@ def sell(self, symbol: str, quantity: int, in_force: str, extended: bool): debugger.debug("SELL failed") return None self.order_queue.append(ret) - debugger.debug(f"SELL: {self.streamer.timestamp}, {symbol}, {quantity}") + debugger.debug(f"SELL: {self.stats.timestamp}, {symbol}, {quantity}") return ret # ================ Helper Functions ====================== diff --git a/harvest/utils.py b/harvest/utils.py index 78677b74..f5edfe10 100644 --- a/harvest/utils.py +++ b/harvest/utils.py @@ -51,6 +51,37 @@ def interval_string_to_enum(str_interval: str): raise ValueError(f"Invalid interval string {str_interval}") +class Stats: + def __init__(self, timestamp=None, timezone=None, interval=None): + self._timestamp = timestamp + self._timezone = timezone + self._interval = interval + + @property + def timestamp(self): + return self._timestamp + + @timestamp.setter + def timestamp(self, value): + self._timestamp = value + + @property + def timezone(self): + return self._timezone + + @timezone.setter + def timezone(self, value): + self._timezone = value + + @property + def interval(self): + return self._interval + + @interval.setter + def interval(self, value): + self._interval = value + + def interval_enum_to_string(enum): try: name = enum.name diff --git a/tests/test_algo.py b/tests/test_algo.py index 913acab2..a2acddbf 100644 --- a/tests/test_algo.py +++ b/tests/test_algo.py @@ -76,7 +76,7 @@ def config(self): s.tick() self.assertListEqual( - t.interval["A"]["aggregations"], [Interval.MIN_15, Interval.DAY_1] + t.stats.interval["A"]["aggregations"], [Interval.MIN_15, Interval.DAY_1] ) def test_rsi(self): diff --git a/tests/test_api.py b/tests/test_api.py index 2f6a3d0e..a4439dc7 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -27,8 +27,8 @@ def test_timeout(self): stream.fetch_account = lambda: None stream.fetch_price_history = lambda x, y, z: pd.DataFrame() stream.fetch_account = lambda: {"cash": 100, "equity": 100} - t = PaperTrader(stream) - stream.trader = t + t = PaperTrader(stream, debug=True) + stream.trader_main = t.main t.set_symbol(["A", "B"]) @@ -43,6 +43,10 @@ def test_timeout(self): # Save the last datapoint of B a_cur = t.storage.load("A", Interval.MIN_1) b_cur = t.storage.load("B", Interval.MIN_1) + print("test0", b_cur) + + # Manually advance timestamp of streamer + stream.timestamp = stream.timestamp + dt.timedelta(minutes=1) # Only send data for A data = gen_data("A", 1) @@ -63,6 +67,7 @@ def test_timeout(self): t.storage.load("A", Interval.MIN_1)["A"]["close"][-1], ) # Check if B has been set to 0 + print("Test", t.storage.load("B", Interval.MIN_1)["B"]) self.assertEqual( b_cur["B"]["close"][-1], t.storage.load("B", Interval.MIN_1)["B"]["close"][-2], @@ -133,84 +138,113 @@ def test_exceptions(self): self.assertEqual(api.create_secret("I dont exists"), False) try: - api.fetch_price_history('A', Interval.MIN_1, now(), now()) + api.fetch_price_history("A", Interval.MIN_1, now(), now()) self.assertTrue(False) except NotImplementedError as e: - self.assertEqual(str(e), "API does not support this streamer method: `fetch_price_history`.") + self.assertEqual( + str(e), + "API does not support this streamer method: `fetch_price_history`.", + ) try: api.fetch_chain_info("A") self.assertTrue(False) except NotImplementedError as e: - self.assertEqual(str(e), "API does not support this streamer method: `fetch_chain_info`.") + self.assertEqual( + str(e), "API does not support this streamer method: `fetch_chain_info`." + ) try: api.fetch_chain_data("A") self.assertTrue(False) except NotImplementedError as e: - self.assertEqual(str(e), "API does not support this streamer method: `fetch_chain_data`.") + self.assertEqual( + str(e), "API does not support this streamer method: `fetch_chain_data`." + ) try: api.fetch_option_market_data("A") self.assertTrue(False) except NotImplementedError as e: - self.assertEqual(str(e), "API does not support this streamer method: `fetch_option_market_data`.") + self.assertEqual( + str(e), + "API does not support this streamer method: `fetch_option_market_data`.", + ) try: api.fetch_account() self.assertTrue(False) except NotImplementedError as e: - self.assertEqual(str(e), "API does not support this broker method: `fetch_account`.") + self.assertEqual( + str(e), "API does not support this broker method: `fetch_account`." + ) try: api.fetch_stock_order_status(0) self.assertTrue(False) except NotImplementedError as e: - self.assertEqual(str(e), "API does not support this broker method: `fetch_stock_order_status`.") + self.assertEqual( + str(e), + "API does not support this broker method: `fetch_stock_order_status`.", + ) try: api.fetch_option_order_status(0) self.assertTrue(False) except NotImplementedError as e: - self.assertEqual(str(e), "API does not support this broker method: `fetch_option_order_status`.") + self.assertEqual( + str(e), + "API does not support this broker method: `fetch_option_order_status`.", + ) try: api.fetch_crypto_order_status(0) self.assertTrue(False) except NotImplementedError as e: - self.assertEqual(str(e), "API does not support this broker method: `fetch_crypto_order_status`.") + self.assertEqual( + str(e), + "API does not support this broker method: `fetch_crypto_order_status`.", + ) try: api.order_stock_limit("buy", "A", 5, 7) self.assertTrue(False) except NotImplementedError as e: - self.assertEqual(str(e), "API does not support this broker method: `order_stock_limit`.") + self.assertEqual( + str(e), "API does not support this broker method: `order_stock_limit`." + ) try: api.order_crypto_limit("buy", "@A", 5, 7) self.assertTrue(False) except NotImplementedError as e: - self.assertEqual(str(e), "API does not support this broker method: `order_crypto_limit`.") + self.assertEqual( + str(e), "API does not support this broker method: `order_crypto_limit`." + ) try: api.order_option_limit("buy", "A", 5, 7, "call", now(), 8) self.assertTrue(False) except NotImplementedError as e: - self.assertEqual(str(e), "API does not support this broker method: `order_option_limit`.") + self.assertEqual( + str(e), "API does not support this broker method: `order_option_limit`." + ) try: - api.buy('A', -1, 0) + api.buy("A", -1, 0) self.assertTrue(False) except NotImplementedError as e: - self.assertEqual(str(e), "API does not support this broker method: `order_stock_limit`.") + self.assertEqual( + str(e), "API does not support this broker method: `order_stock_limit`." + ) try: - api.sell('A', -1, 0) + api.sell("A", -1, 0) self.assertTrue(False) except NotImplementedError as e: - self.assertEqual(str(e), "API does not support this broker method: `order_stock_limit`.") - - + self.assertEqual( + str(e), "API does not support this broker method: `order_stock_limit`." + ) def test_base_cases(self): api = API() diff --git a/tests/test_api_dummy.py b/tests/test_api_dummy.py index ec45a21f..5d8cee2c 100644 --- a/tests/test_api_dummy.py +++ b/tests/test_api_dummy.py @@ -39,7 +39,8 @@ def test_setup(self): "C": {"interval": Interval.MIN_1, "agg,regations": []}, "@D": {"interval": Interval.MIN_1, "agg,regations": []}, } - dummy.setup(interval) + stats = Stats(interval=interval) + dummy.setup(stats) self.assertEqual(dummy.interval, interval) @@ -51,7 +52,9 @@ def test_get_stock_price(self): "C": {"interval": Interval.MIN_1, "agg,regations": []}, "@D": {"interval": Interval.MIN_1, "agg,regations": []}, } - dummy.setup(interval) + stats = Stats(interval=interval) + dummy.setup(stats) + d = dummy.fetch_latest_stock_price() self.assertEqual(len(d), 3) @@ -63,7 +66,9 @@ def test_get_crypto_price(self): "C": {"interval": Interval.MIN_1, "agg,regations": []}, "@D": {"interval": Interval.MIN_1, "agg,regations": []}, } - dummy.setup(interval) + stats = Stats(interval=interval) + dummy.setup(stats) + d = dummy.fetch_latest_crypto_price() self.assertTrue("@D" in d) self.assertEqual(d["@D"].shape, (1, 5)) diff --git a/tests/test_api_paper.py b/tests/test_api_paper.py index 6b2c9671..73e67ed2 100644 --- a/tests/test_api_paper.py +++ b/tests/test_api_paper.py @@ -36,7 +36,8 @@ def test_dummy_account(self): def test_buy_order_limit(self): dummy = PaperBroker() interval = {"A": {"interval": Interval.MIN_1, "aggregations": []}} - dummy.setup(interval) + stats = Stats(interval=interval) + dummy.setup(stats) order = dummy.order_stock_limit("buy", "A", 5, 50000) self.assertEqual(order["type"], "STOCK") self.assertEqual(order["id"], 0) @@ -54,7 +55,8 @@ def test_buy_order_limit(self): def test_buy(self): dummy = PaperBroker() interval = {"A": {"interval": Interval.MIN_1, "aggregations": []}} - dummy.setup(interval) + stats = Stats(interval=interval) + dummy.setup(stats) order = dummy.buy("A", 5, 1e5) self.assertEqual(order["type"], "STOCK") self.assertEqual(order["id"], 0) @@ -72,8 +74,11 @@ def test_buy(self): def test_sell_order_limit(self): directory = pathlib.Path(__file__).parent.resolve() dummy = PaperBroker(str(directory) + "/../dummy_account.yaml") + interval = {"A": {"interval": Interval.MIN_1, "aggregations": []}} - dummy.setup(interval) + stats = Stats(interval=interval) + dummy.setup(stats) + order = dummy.order_stock_limit("sell", "A", 2, 50000) self.assertEqual(order["type"], "STOCK") self.assertEqual(order["id"], 0) @@ -91,8 +96,11 @@ def test_sell_order_limit(self): def test_sell(self): directory = pathlib.Path(__file__).parent.resolve() dummy = PaperBroker(str(directory) + "/../dummy_account.yaml") + interval = {"A": {"interval": Interval.MIN_1, "aggregations": []}} - dummy.setup(interval) + stats = Stats(interval=interval) + dummy.setup(stats) + order = dummy.sell("A", 2) self.assertEqual(order["type"], "STOCK") self.assertEqual(order["id"], 0) @@ -109,8 +117,11 @@ def test_sell(self): def test_order_option_limit(self): dummy = PaperBroker() + interval = {"A": {"interval": Interval.MIN_1, "aggregations": []}} - dummy.setup(interval) + stats = Stats(interval=interval) + dummy.setup(stats) + exp_date = dt.datetime(2021, 11, 14) + dt.timedelta(hours=5) order = dummy.order_option_limit( "buy", "A", 5, 50000, "OPTION", exp_date, 50001 diff --git a/tests/test_api_yahoo.py b/tests/test_api_yahoo.py index 46994d1a..3a623fd2 100644 --- a/tests/test_api_yahoo.py +++ b/tests/test_api_yahoo.py @@ -23,8 +23,9 @@ def test_setup(self): "SPY": {"interval": Interval.MIN_15, "aggregations": []}, "AAPL": {"interval": Interval.MIN_1, "aggregations": []}, } + stats = Stats(interval=interval) + yh.setup(stats) - yh.setup(interval) self.assertEqual(yh.poll_interval, Interval.MIN_1) self.assertListEqual([s for s in yh.interval], ["SPY", "AAPL"]) @@ -42,8 +43,9 @@ def test_main(df): self.assertEqual(df["@BTC"].columns[0][0], "@BTC") yh = YahooStreamer() - watch = ["SPY", "AAPL", "@BTC"] - yh.setup(interval, test_main) + stats = Stats(interval=interval) + yh.setup(stats, test_main) + yh.main() def test_main_single(self): @@ -54,22 +56,29 @@ def test_main(df): self.assertEqual(df["SPY"].columns[0][0], "SPY") yh = YahooStreamer() - yh.setup(interval, test_main) + stats = Stats(interval=interval) + yh.setup(stats, test_main) + yh.main() def test_chain_info(self): - t = PaperTrader() yh = YahooStreamer() + interval = {"LMND": {"interval": Interval.MIN_1, "aggregations": []}} - yh.setup(interval) + stats = Stats(interval=interval) + yh.setup(stats) + info = yh.fetch_chain_info("LMND") self.assertGreater(len(info["exp_dates"]), 0) def test_chain_data(self): - t = PaperTrader() + yh = YahooStreamer() + interval = {"LMND": {"interval": Interval.MIN_1, "aggregations": []}} - yh.setup(interval) + stats = Stats(interval=interval) + yh.setup(stats) + dates = yh.fetch_chain_info("LMND")["exp_dates"] data = yh.fetch_chain_data("LMND", dates[0]) self.assertGreater(len(data), 0) diff --git a/tests/test_tester.py b/tests/test_tester.py index b0bcc67a..940c1e20 100644 --- a/tests/test_tester.py +++ b/tests/test_tester.py @@ -42,9 +42,9 @@ def test_check_aggregation(self): t.start("1MIN", ["1DAY"], period="1DAY") minutes = list(t.storage.load("A", Interval.MIN_1)["A"]["close"])[-200:] - days_agg = list( - t.storage.load("A", int(Interval.DAY_1) - 16)["A"]["close"] - )[-200:] + days_agg = list(t.storage.load("A", int(Interval.DAY_1) - 16)["A"]["close"])[ + -200: + ] self.assertListEqual(minutes, days_agg)