diff --git a/README.md b/README.md index 967de5bd..bcf9911d 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ ![Header](docs/banner.png)
-Harvest is a Python framework providing a **simple** and **flexible** framework for algorithmic trading. Visit Harvest's [**website**](https://tfukaza.github.io/harvest-website/) for details, tutorials, and documentation. +Harvest is a simple yet flexible Python framework for algorithmic trading. Paper trade and live trade stocks, cryptos, and options![^1][^2] Visit Harvest's [**website**](https://tfukaza.github.io/harvest-website/) for details, tutorials, and documentation.
@@ -77,3 +77,6 @@ Currently looking for... - Many of the brokers were also not designed to be used for algo-trading. Excessive access to their API can result in your account getting locked. - Tutorials and documentation solely exist to provide technical references of the code. They are not recommendations of any specific securities or strategies. - Use Harvest at your own responsibility. Developers of Harvest take no responsibility for any financial losses you incur by using Harvest. By using Harvest, you certify you understand that Harvest is a software in early development and may contain bugs and unexpected behaviors. + +[^1]: What assets you can trade depends on the broker you are using. +[^2]: Backtesting is also available, but it is not supported for options. diff --git a/docs/dev.md b/docs/dev.md index 3207ae8d..e92d77d3 100644 --- a/docs/dev.md +++ b/docs/dev.md @@ -1,7 +1,22 @@ -# Overview +## The Harvest Workflow -### Harvest architecture -![Harvest architecture](harvest_architecture.png) +Because Harvest is an extensive framework it can be hard to understand how the system works at times. This document server to provide a high-level overview of just how Harvests works after a user starts the trader. -### Harvest Flow on Trader `start()` -![Harvest start flow](harvest_start_flow.png) \ No newline at end of file +## Fetching Data + +After the user starts the trader Harvest will fetch data from the streamer and update its storage on interval. + +![Fetching Data Workflow](fetch-data.png) + +1. First Harvest will run the streamer on the specified interval. Once the data has been collected, the streamer will call a callback or hook function that will pass operation back to the trader. In this callback function the streamer will return the latest OHLC data for the assets specified by the trader. +2. In this callback, the trader will update the storage with the latest data and the will run each algorithm. + +## Running Algorithms + +After data is fetched, the algorithms are run linearly. + +![Running Algorithm Workflow](run-algo.png) + +1. The algorithm the user created will user functions provided in the `BaseAlgo` class which communicate with the Trader. +2. Typically the user's algorithms will first ask for data on the assets they specified which will be stored in the Storage. +3. After that the user's algoirthms will decided when to buy or sell assets based on the data they got from the Storage. This will leverage the Broker. \ No newline at end of file diff --git a/docs/fetch-data.png b/docs/fetch-data.png new file mode 100644 index 00000000..30a4eb65 Binary files /dev/null and b/docs/fetch-data.png differ diff --git a/docs/harvest_architecture.png b/docs/harvest_architecture.png deleted file mode 100644 index 239acb3d..00000000 Binary files a/docs/harvest_architecture.png and /dev/null differ diff --git a/docs/harvest_start_flow.png b/docs/harvest_start_flow.png deleted file mode 100644 index e14cd4a5..00000000 Binary files a/docs/harvest_start_flow.png and /dev/null differ diff --git a/docs/run-algo.png b/docs/run-algo.png new file mode 100644 index 00000000..c01f09f7 Binary files /dev/null and b/docs/run-algo.png differ diff --git a/examples/em_alpaca.py b/examples/em_alpaca.py index 8213b40b..9ee6f518 100644 --- a/examples/em_alpaca.py +++ b/examples/em_alpaca.py @@ -1,3 +1,4 @@ +# HARVEST_SKIP # Builtin imports import logging import datetime as dt diff --git a/examples/em_kraken.py b/examples/em_kraken.py new file mode 100644 index 00000000..909dd88b --- /dev/null +++ b/examples/em_kraken.py @@ -0,0 +1,100 @@ +# HARVEST_SKIP +# Builtin imports +import logging +import datetime as dt + +# Harvest imports +from harvest.algo import BaseAlgo +from harvest.trader import LiveTrader +from harvest.api.kraken import Kraken +from harvest.storage.csv_storage import CSVStorage + +# Third-party imports +import pandas as pd +import matplotlib.pyplot as plt +import mplfinance as mpf + + +class EMAlgo(BaseAlgo): + def setup(self): + now = dt.datetime.now() + logging.info(f"EMAlgo.setup ran at: {now}") + + def init_ticker(ticker): + fig = mpf.figure() + ax1 = fig.add_subplot(2, 1, 1) + ax2 = fig.add_subplot(3, 1, 3) + + return { + ticker: { + "initial_price": None, "ohlc": pd.DataFrame(), + "fig": fig, + "ax1": ax1, + "ax2": ax2 + } + } + + self.tickers = {} + # self.tickers.update(init_ticker("@BTC")) + self.tickers.update(init_ticker("@DOGE")) + + def main(self): + now = dt.datetime.now() + logging.info(f"EMAlgo.main ran at: {now}") + + if now - now.replace(hour=0, minute=0, second=0, microsecond=0) <= dt.timedelta( + seconds=60 + ): + logger.info(f"It's a new day! Clearning OHLC caches!") + for ticker_value in self.tickers.values(): + ticker_value["ohlc"] = pd.DataFrame() + + for ticker, ticker_value in self.tickers.items(): + current_price = self.get_asset_price(ticker) + current_ohlc = self.get_asset_candle(ticker) + if ticker_value["initial_price"] is None: + ticker_value["initial_price"] = current_price + + if current_ohlc.empty: + logging.warn(f"{ticker}'s get_asset_candle_list returned an empty list.") + return + + ticker_value["ohlc"] = ticker_value["ohlc"].append(current_ohlc) + + self.process_ticker(ticker, ticker_value, current_price) + + def process_ticker(self, ticker, ticker_data, current_price): + initial_price = ticker_data["initial_price"] + ohlc = ticker_data["ohlc"] + + # Calculate the price change + delta_price = current_price - initial_price + + # Print stock info + logging.info(f"{ticker} current price: ${current_price}") + logging.info(f"{ticker} price change: ${delta_price}") + + # Update the OHLC graph + ticker_data['ax1'].clear() + ticker_data['ax2'].clear() + mpf.plot(ohlc, ax=ticker_data['ax1'], volume=ticker_data['ax2'], type="candle") + plt.pause(3) + + +if __name__ == "__main__": + # Store the OHLC data in a folder called `em_storage` with each file stored as a csv document + csv_storage = CSVStorage(save_dir="em_storage") + # Our streamer and broker will be Alpaca. My secret keys are stored in `alpaca_secret.yaml` + kraken = Kraken( + path="accounts/kraken-secret.yaml" + ) + em_algo = EMAlgo() + trader = LiveTrader(streamer=kraken, broker=kraken, storage=csv_storage, debug=True) + + # trader.set_symbol("@BTC") + trader.set_symbol("@DOGE") + trader.set_algo(em_algo) + mpf.show() + + # Update every minute + trader.start("1MIN", all_history=False) diff --git a/examples/em_polygon.py b/examples/em_polygon.py index 4f6430f8..95ebcda6 100644 --- a/examples/em_polygon.py +++ b/examples/em_polygon.py @@ -1,3 +1,4 @@ +# HARVEST_SKIP # Builtin imports import logging import datetime as dt diff --git a/harvest/algo.py b/harvest/algo.py index de12af96..18610722 100644 --- a/harvest/algo.py +++ b/harvest/algo.py @@ -141,7 +141,7 @@ def sell( debugger.debug(f"Algo SELL: {symbol}, {quantity}") return self.trader.sell(symbol, quantity, in_force, extended) - + def sell_all_options(self, symbol: str = None, in_force: str = "gtc"): """Sells all options of a stock @@ -163,8 +163,8 @@ def sell_all_options(self, symbol: str = None, in_force: str = "gtc"): for s in symbols: debugger.debug(f"Algo SELL OPTION: {s}") quantity = self.get_asset_quantity(s) - ret.append(self.trader.sell_option(s, quantity, in_force)) - + ret.append(self.trader.sell(s, quantity, in_force, True)) + return ret # def buy_option(self, symbol: str, quantity: int = None, in_force: str = "gtc"): @@ -522,9 +522,8 @@ def get_asset_quantity(self, symbol: str = None) -> float: """ if symbol is None: symbol = self.watchlist[0] - - return self.trader.get_asset_quantity(symbol, exclude_pending_sell=True) + return self.trader.get_asset_quantity(symbol, exclude_pending_sell=True) def get_asset_cost(self, symbol: str = None) -> float: """Returns the average cost of a specified asset. @@ -771,7 +770,9 @@ def get_datetime(self): :returns: The current date and time as a datetime object """ - return datetime_utc_to_local(self.trader.timestamp, self.trader.timezone) + return datetime_utc_to_local( + self.trader.streamer.timestamp, self.trader.timezone + ) def get_option_position_quantity(self, symbol: str = None) -> bool: """Returns the number of types of options held for a stock. diff --git a/harvest/api/_base.py b/harvest/api/_base.py index 7099cf51..e0c86b8f 100644 --- a/harvest/api/_base.py +++ b/harvest/api/_base.py @@ -22,7 +22,7 @@ class API: Attributes :interval_list: A list of supported intervals. - :exchange: The market the API trades on. Ignored if the API is not a broker. + :exchange: The market the API trades on. Ignored if the API is not a broker. """ interval_list = [ @@ -53,11 +53,6 @@ def __init__(self, path: str = None): :path: path to the YAML file containing credentials to communicate with the API. If not specified, defaults to './secret.yaml' """ - self.trader = ( - None # Allows broker to handle the case when runs without a trader - ) - - self.run_count = 0 if path is None: path = "./secret.yaml" @@ -69,14 +64,15 @@ def __init__(self, path: str = None): with open(path, "r") as stream: self.config = yaml.safe_load(stream) - self.run_count = 0 self.timestamp = now() def create_secret(self, path: str): """ This method is called when the yaml file with credentials is not found.""" - raise Exception(f"{path} was not found.") + # raise Exception(f"{path} was not found.") + debugger.warning(f"Assuming API does not need account information.") + return False def refresh_cred(self): """ @@ -85,14 +81,15 @@ def refresh_cred(self): """ debugger.info(f"Refreshing credentials for {type(self).__name__}.") - def setup(self, interval: Dict, trader=None, trader_main=None) -> None: + def setup(self, interval: Dict, trader_main=None) -> None: """ This function is called right before the algorithm begins, and initializes several runtime parameters like the symbols to watch and what interval data is needed. + + :trader_main: A callback function to the trader which will pass the data to the algorithms. """ - self.trader = trader self.trader_main = trader_main min_interval = None @@ -182,12 +179,12 @@ def main(self): df_dict = {} for sym in self.interval: inter = self.interval[sym]["interval"] - if is_freq(self.timestamp, inter): - n = self.timestamp + if is_freq(harvest_timestamp, inter): + n = harvest_timestamp latest = self.fetch_price_history( sym, inter, n - interval_to_timedelta(inter) * 2, n ) - debugger.debug(f"Price fetch returned: \n{latest}") + debugger.debug(f"{sym} price fetch returned: {latest}") if latest is None or latest.empty: continue df_dict[sym] = latest.iloc[-1] @@ -229,29 +226,40 @@ def wrapper(*args, **kwargs): return wrapper def _run_once(func): - """ """ + """ + Wrapper to only allows wrapped functions to be run once. + + :func: Function to wrap. + :returns: The return of the inputted function if it has not been run before and None otherwise. + """ + + ran = False def wrapper(*args, **kwargs): - self = args[0] - if self.run_count == 0: - self.run_count += 1 - return func(args, kwargs) + nonlocal ran + if not ran: + ran = True + return func(*args, **kwargs) return None return wrapper # -------------- Streamer methods -------------- # + def get_current_time(self): + return now() + def fetch_price_history( self, symbol: str, interval: Interval, start: dt.datetime = None, end: dt.datetime = None, - ): + ) -> pd.DataFrame: """ Fetches historical price data for the specified asset and period - using the API. + using the API. The first row is the earliest entry and the last + row is the latest entry. :param symbol: The stock/crypto to get data for. :param interval: The interval of requested historical data. @@ -306,10 +314,10 @@ def fetch_option_market_data(self, symbol: str): raise NotImplementedError( f"{type(self).__name__} does not support this streamer method: `fetch_option_market_data`." ) - + def fetch_market_hours(self, date: datetime.date): """ - Returns the market hours for a given day. + Returns the market hours for a given day. Hours are based on the exchange specified in the class's 'exchange' attribute. :returns: A dictionary with the following keys and values: @@ -528,7 +536,7 @@ def order_stock_limit( Raises an exception if order fails. """ raise NotImplementedError( - f"{type(self).__name__} does not support this broker method: `order_limit`." + f"{type(self).__name__} does not support this broker method: `order_stock_limit`." ) def order_crypto_limit( @@ -556,7 +564,7 @@ def order_crypto_limit( Raises an exception if order fails. """ raise NotImplementedError( - f"{type(self).__name__} does not support this broker method: `order_limit`." + f"{type(self).__name__} does not support this broker method: `order_crypto_limit`." ) def order_option_limit( @@ -596,7 +604,12 @@ def order_option_limit( # These do not need to be re-implemented in a subclass def buy( - self, symbol: str, quantity: int, limit_price: float, in_force: str = "gtc", extended: bool = False + self, + symbol: str, + quantity: int, + limit_price: float, + in_force: str = "gtc", + extended: bool = False, ): """ Buys the specified asset. @@ -653,7 +666,7 @@ def sell( :returns: The result of order_limit(). Returns None if there is an issue with the parameters. """ - + debugger.debug(f"{type(self).__name__} ordered a sell of {quantity} {symbol}") typ = symbol_type(symbol) @@ -707,8 +720,8 @@ def sell( # if total_price >= buy_power: # debugger.warning( - # "Not enough buying power.\n" + - # f"Total price ({price} * {quantity} * 1.05 = {limit_price*quantity}) exceeds buying power {buy_power}.\n" + + # "Not enough buying power.\n" + + # f"Total price ({price} * {quantity} * 1.05 = {limit_price*quantity}) exceeds buying power {buy_power}.\n" + # "Reduce purchase quantity or increase buying power." # ) @@ -817,8 +830,8 @@ def __init__(self, path: str = None): self.block_queue = {} self.first = True - def setup(self, interval: Dict, trader=None, trader_main=None) -> None: - super().setup(interval, trader, trader_main) + def setup(self, interval: Dict, trader_main=None) -> None: + super().setup(interval, trader_main) self.blocker = {} def start(self): @@ -878,15 +891,15 @@ def timeout(self): self.flush() def flush(self): - # For missing data, repeat the existing one + # For missing data, return a OHLC with all zeroes. self.block_lock.acquire() for n in self.needed: - data = ( - self.trader.storage.load(n, self.interval[n]["interval"]) - .iloc[[-1]] - .copy() + data = pd.DataFrame( + {"open": 0, "high": 0, "low": 0, "close": 0, "volume": 0}, + index=[self.timestamp], ) - data.index = [self.timestamp] + + data.columns = pd.MultiIndex.from_product([[n], data.columns]) self.block_queue[n] = data self.block_lock.release() self.trader_main(self.block_queue) diff --git a/harvest/api/alpaca.py b/harvest/api/alpaca.py index c30f9b29..f1a52f24 100644 --- a/harvest/api/alpaca.py +++ b/harvest/api/alpaca.py @@ -76,8 +76,8 @@ async def update_data(self, bar): else: self.data_lock.release() - def setup(self, interval: Dict, trader=None, trader_main=None): - super().setup(interval, trader, trader_main) + def setup(self, interval: Dict, trader_main=None): + super().setup(interval, trader_main) self.watch_stock = [] self.watch_crypto = [] diff --git a/harvest/api/dummy.py b/harvest/api/dummy.py index 87083230..e30a6f63 100644 --- a/harvest/api/dummy.py +++ b/harvest/api/dummy.py @@ -27,21 +27,22 @@ class DummyStreamer(API): Interval.HR_1, Interval.DAY_1, ] - default_now = dt.datetime(year=2000, month=1, day=1, hour=0, minute=0) + + default_timestamp = dt.datetime(year=2000, month=1, day=1, hour=0, minute=0) def __init__( self, - path: str = None, - now: dt.datetime = default_now, + timestamp: dt.datetime = default_timestamp, realistic_times: bool = False, ): - self.trader = None + + super().__init__(None) + self.trader_main = None self.realistic_times = realistic_times # Set the current time - self._set_now(now) - self.timestamp = self.now + self._set_timestamp(timestamp) # Used so `fetch_price_history` can work without running `setup` self.interval = self.interval_list[0] # Store random values and generates for each asset tot make `fetch_price_history` fixed @@ -65,7 +66,7 @@ def fetch_latest_stock_price(self) -> Dict[str, pd.DataFrame]: """ results = {} - today = self.now + today = self.timestamp last = today - dt.timedelta(days=3) for symbol in self.interval: @@ -83,7 +84,7 @@ def fetch_latest_crypto_price(self) -> Dict[str, pd.DataFrame]: """ results = {} - today = self.now + today = self.timestamp last = today - dt.timedelta(days=3) for symbol in self.interval: if is_crypto(symbol): @@ -94,6 +95,9 @@ def fetch_latest_crypto_price(self) -> Dict[str, pd.DataFrame]: # -------------- Streamer methods -------------- # + def get_current_time(self): + return self.timestamp + def fetch_price_history( self, symbol: str, @@ -109,14 +113,14 @@ def fetch_price_history( Interval.MIN_15, Interval.MIN_30, ]: - start = self.now - dt.timedelta(days=2) + start = self.timestamp - dt.timedelta(days=2) elif interval == Interval.HR_1: - start = self.now - dt.timedelta(days=14) + start = self.timestamp - dt.timedelta(days=14) else: - start = self.now - dt.timedelta(days=365) + start = self.timestamp - dt.timedelta(days=365) if end is None: - end = self.now + end = self.timestamp if start.tzinfo is None or start.tzinfo.utcoffset(start) is None: start = pytz.utc.localize(start) @@ -168,7 +172,7 @@ def fetch_price_history( self.randomness[symbol + "_rng"] = rng # The inital price is arbitarly calculated from the first change in price - start_price = 1000 * (self.randomness[symbol][0] + 0.51) + start_price = 1000 * (self.randomness[symbol][0] + 0.501) times = [] current_time = start @@ -218,7 +222,6 @@ def fetch_price_history( results.columns = pd.MultiIndex.from_product([[symbol], results.columns]) results = aggregate_df(results, interval) - return results # TODO: Generate dummy option data @@ -227,7 +230,7 @@ def fetch_option_market_data(self, symbol: str): # This is a placeholder so Trader doesn't crash message = hashlib.sha256() message.update(symbol.encode("utf-8")) - message.update(str(self.now).encode("utf-8")) + message.update(str(self.timestamp).encode("utf-8")) hsh = message.digest() price = int.from_bytes(hsh[:4], "big") / (2 ** 32) price = (price + 1) * 1.5 @@ -267,16 +270,16 @@ def fetch_option_market_data(self, symbol: str): # ------------- Helper methods ------------- # - def _set_now(self, current_datetime: dt.datetime) -> None: + def _set_timestamp(self, current_datetime: dt.datetime) -> None: if ( current_datetime.tzinfo is None or current_datetime.tzinfo.utcoffset(current_datetime) is None ): - self.now = pytz.utc.localize(current_datetime) + self.timestamp = pytz.utc.localize(current_datetime) else: - self.now = current_datetime + self.timestamp = current_datetime def tick(self) -> None: - self.now += interval_to_timedelta(self.poll_interval) + self.timestamp += interval_to_timedelta(self.poll_interval) if not self.trader_main == None: self.main() diff --git a/harvest/api/kraken.py b/harvest/api/kraken.py index a6a8cb91..4b01cba0 100644 --- a/harvest/api/kraken.py +++ b/harvest/api/kraken.py @@ -14,16 +14,10 @@ class Kraken(API): - interval_list = [ - Interval.MIN_1, - Interval.MIN_5, - Interval.MIN_15, - Interval.MIN_30, - Interval.HR_1, - Interval.DAY_1, - ] + interval_list = [Interval.MIN_1, Interval.MIN_5, Interval.HR_1, Interval.DAY_1] + crypto_ticker_to_kraken_names = { - "BTC": "XXBT", + "BTC": "XXBTZ", "ETH": "XETH", "ADA": "ADA", "USDT": "USDT", @@ -113,19 +107,24 @@ class Kraken(API): "OXT": "OXT", } + kraken_names_to_crypto_ticker = { + v: k for k, v in crypto_ticker_to_kraken_names.items() + } + def __init__(self, path: str = None): super().__init__(path) self.api = krakenex.API(self.config["api_key"], self.config["secret_key"]) - def setup(self, watch: List[str], interval: str, trader=None, trader_main=None): + def setup(self, interval: Dict, trader_main=None): + super().setup(interval, trader_main) self.watch_crypto = [] - if is_crypto(s): - self.watch_crypto.append(s) - else: - debugger.error("Kraken does not support stocks.") + for sym in interval: + if is_crypto(sym): + self.watch_crypto.append(sym) + else: + debugger.warning(f"Kraken does not support stocks. Ignoring {sym}.") self.option_cache = {} - super().setup(watch, interval, interval, trader, trader_main) def exit(self): self.option_cache = {} @@ -139,10 +138,13 @@ def main(self): @API._exception_handler def fetch_latest_crypto_price(self): dfs = {} - for symbol in self.watch_cryptos: + for symbol in self.watch_crypto: dfs[symbol] = self.fetch_price_history( - symbol, self.interval, now() - dt.timedelta(days=7), now() - ).iloc[[0]] + symbol, + self.interval[symbol]["interval"], + now() - dt.timedelta(days=7), + now(), + ).iloc[[-1]] return dfs # -------------- Streamer methods -------------- # @@ -151,7 +153,7 @@ def fetch_latest_crypto_price(self): def fetch_price_history( self, symbol: str, - interval: str, + interval: Interval, start: dt.datetime = None, end: dt.datetime = None, ): @@ -175,17 +177,16 @@ def fetch_price_history( f"Interval {interval} not in interval list. Possible options are: {self.interval_list}" ) val, unit = expand_interval(interval) - return self.get_data_from_kraken(symbol, val, unit, start, end) + df = self.get_data_from_kraken(symbol, val, unit, start, end) + + return df - @API._exception_handler def fetch_chain_info(self, symbol: str): raise NotImplementedError("Kraken does not support options.") - @API._exception_handler def fetch_chain_data(self, symbol: str, date: dt.datetime): raise NotImplementedError("Kraken does not support options.") - @API._exception_handler def fetch_option_market_data(self, occ_symbol: str): raise NotImplementedError("Kraken does not support options.") @@ -203,51 +204,88 @@ def fetch_option_positions(self): @API._exception_handler def fetch_crypto_positions(self): - return self.get_result(self.api.query_private("OpenOrders")) + positions = self.get_result(self.api.query_private("OpenPositions")) + + def fmt(crypto: Dict[str, Any]): + # Remove the currency + symbol = crypto["pair"][:-4] + # Convert from kraken name to crypto currency ticker + symbol = kraken_name_to_crypto_ticker.get(symbol) + return { + "symbol": "@" + symbol, + "avg_price": float(crypto["cost"]) / float(crypto["vol"]), + "quantity": float(crypto["vol"]), + "kraken": crypto, + } + + return [fmt(pos) for pos in positions] - @API._exception_handler def update_option_positions(self, positions: List[Any]): - raise NotImplementedError("Kraken does not support options.") + debugger.error("Kraken does not support options. Doing nothing.") @API._exception_handler def fetch_account(self): - return self.get_result(self.api.query_private("Balance")) + account = self.get_result(self.api.query_private("Balance")) + if account is None: + equity = 0 + cash = 0 + else: + equity = sum(float(v) for k, v in account.items() if k != "ZUSD") + cash = account.get("ZUSD", 0) + return { + "equity": equity, + "cash": cash, + "buying_power": equity + cash, + "multiplier": 1, + "kraken": account, + } - @API._exception_handler def fetch_stock_order_status(self, order_id: str): return NotImplementedError("Kraken does not support stocks.") - @API._exception_handler - def fetch_option_order_status(self, id): + def fetch_option_order_status(self, order_id: str): raise Exception("Kraken does not support options.") @API._exception_handler - def fetch_crypto_order_status(self, id: str): - closed_orders = self.get_result(self.api.query_private("ClosedOrders")) - orders = closed_orders["closed"] + self.fetch_order_queue() - if id in orders.keys(): - return orders.get(id) - raise Exception(f"{id} not found in your orders.") + def fetch_crypto_order_status(self, order_id: str): + order = self.api.query_private("QueryOrders", {"txid": order_id}) + symbol = kraken_names_to_crypto_ticker.get(crypto["descr"]["pair"][:-4]) + return { + "type": "CRYPTO", + "symbol": "@" + symbol, + "id": crypto.key(), + "quantity": float(crypto["vol"]), + "filled_quantity": float(crypto["vol_exec"]), + "side": crypto["descr"]["type"], + "time_in_force": None, + "status": crypto["status"], + "kraken": crypto, + } # --------------- Methods for Trading --------------- # @API._exception_handler def fetch_order_queue(self): open_orders = self.get_result(self.api.query_private("OpenOrders")) - return open_orders["open"] - - def order_stock_limit( - self, - symbol: str, - quantity: float, - limit_price: float, - order_type: str, - time_in_force: str, - order_id: str = None, - ): - raise NotImplementedError("Kraken does not support stocks.") - - def order_crypto_limit( + open_orders = open_orders["open"] + + def fmt(crypto: Dict[str, Any]): + symbol = kraken_names_to_crypto_ticker.get(crypto["descr"]["pair"][:-4]) + return { + "type": "CRYPTO", + "symbol": "@" + symbol, + "id": crypto.key(), + "quantity": float(crypto["vol"]), + "filled_quantity": float(crypto["vol_exec"]), + "side": crypto["descr"]["type"], + "time_in_force": None, + "status": crypto["status"], + "kraken": crypto, + } + + return [fmt(order) for order in open_orders] + + def order_limit( self, side: str, symbol: str, @@ -256,10 +294,12 @@ def order_crypto_limit( in_force: str = "gtc", extended: bool = False, ): + if is_crypto(symbol): + symbol = ticker_to_kraken(symbol) + else: + raise Exception("Kraken does not support stocks.") - symbol = self.ticker_to_kraken(symbol) - - return self.get_result( + order = self.get_result( self.api.query_private( "AddOrder", { @@ -271,6 +311,13 @@ def order_crypto_limit( ) ) + return { + "type": "CRYPTO", + "id": order["txid"], + "symbol": symbol, + "kraken": order, + } + def order_option_limit( self, side: str, @@ -336,6 +383,29 @@ def _format_df(self, df: pd.DataFrame, symbol: str): return df.dropna() + def ticker_to_kraken(self, ticker: str): + if not is_crypto(ticker): + raise Exception("Kraken does not support stocks.") + + if ticker[1:] in self.crypto_ticker_to_kraken_names: + # Currently Harvest supports trades for USD and not other currencies. + kraken_ticker = self.crypto_ticker_to_kraken_names.get(ticker[1:]) + "USD" + asset_pairs = self.get_result(self.api.query_public("AssetPairs")).keys() + if kraken_ticker in asset_pairs: + return kraken_ticker + else: + raise Exception(f"{kraken_ticker} is not a valid asset pair.") + else: + raise Exception(f"Kraken does not support ticker {ticker}.") + + def get_result(self, response: Dict[str, Any]): + """Given a kraken response from an endpoint, either raise an error if an + error exists or return the data in the results key. + """ + if len(response["error"]) > 0: + raise Exception("\n".join(response["error"])) + return response.get("result", None) + def create_secret(self, path: str) -> bool: import harvest.wizard as wizard @@ -382,26 +452,3 @@ def create_secret(self, path: str) -> bool: ) return True - - def ticker_to_kraken(self, ticker: str): - if not is_crypto(ticker): - raise Exception("Kraken does not support stocks.") - - if ticker[1:] not in self.crypto_ticker_to_kraken_names: - raise Exception(f"Kraken does not support ticker {ticker}.") - - # Currently Harvest supports trades for USD and not other currencies. - kraken_ticker = self.crypto_ticker_to_kraken_names.get(ticker[1:]) + "USD" - asset_pairs = self.get_result(self.api.query_public("AssetPairs")).keys() - if kraken_ticker in asset_pairs: - return kraken_ticker - else: - raise Exception(f"{kraken_ticker} is not a valid asset pair.") - - def get_result(self, response: Dict[str, Any]): - """Given a kraken response from an endpoint, either raise an error if an - error exists or return the data in the results key. - """ - if len(response["error"]) > 0: - raise Exception("\n".join(response["error"])) - return response["result"] diff --git a/harvest/api/paper.py b/harvest/api/paper.py index 4661b437..99d00e14 100644 --- a/harvest/api/paper.py +++ b/harvest/api/paper.py @@ -9,6 +9,7 @@ # Submodule imports from harvest.api._base import API +from harvest.api.dummy import DummyStreamer from harvest.utils import * @@ -27,7 +28,7 @@ class PaperBroker(API): Interval.DAY_1, ] - def __init__(self, account_path: str = None, commission_fee=0): + def __init__(self, account_path: str = None, commission_fee=0, streamer=None): """ :commission_fee: When this is a number it is assumed to be a flat price on all buys and sells of assets. When this is a string formatted as @@ -49,6 +50,7 @@ def __init__(self, account_path: str = None, commission_fee=0): self.multiplier = 1 self.commission_fee = commission_fee self.id = 0 + self.streamer = DummyStreamer() if streamer is None else streamer if account_path: with open(account_path, "r") as f: @@ -64,8 +66,8 @@ def __init__(self, account_path: str = None, commission_fee=0): for crypto in account["cryptos"]: self.cryptos.append(crypto) - def setup(self, interval, trader=None, trader_main=None): - super().setup(interval, trader, trader_main) + def setup(self, interval, trader_main=None): + super().setup(interval, trader_main) # -------------- Streamer methods -------------- # @@ -89,11 +91,7 @@ def fetch_crypto_positions(self) -> List[Dict[str, Any]]: def update_option_positions(self, positions) -> List[Dict[str, Any]]: for r in self.options: occ_sym = r["symbol"] - - if self.trader is None: - price = self.fetch_option_market_data(occ_sym)["price"] - else: - price = self.trader.streamer.fetch_option_market_data(occ_sym)["price"] + price = self.streamer.fetch_option_market_data(occ_sym)["price"] r["current_price"] = price r["market_value"] = price * r["quantity"] * 100 @@ -112,17 +110,12 @@ def fetch_stock_order_status(self, id: int) -> Dict[str, Any]: ret = next(r for r in self.orders if r["id"] == id) sym = ret["symbol"] - if self.trader is None: - price = self.streamer.fetch_price_history( - sym, - self.interval[sym]["interval"], - dt.datetime.now() - dt.timedelta(days=7), - dt.datetime.now(), - )[sym]["close"][-1] - else: - price = self.trader.storage.load(sym, self.interval[sym]["interval"])[sym][ - "close" - ][-1] + price = self.streamer.fetch_price_history( + sym, + self.interval[sym]["interval"], + self.streamer.get_current_time() - dt.timedelta(days=7), + self.streamer.get_current_time(), + )[sym]["close"][-1] qty = ret["quantity"] original_price = price * qty @@ -160,7 +153,7 @@ def fetch_stock_order_status(self, id: int) -> Dict[str, Any]: self.orders.remove(ret) ret = ret_1 ret["status"] = "filled" - ret["filled_time"] = self.trader.timestamp + ret["filled_time"] = self.streamer.get_current_time() ret["filled_price"] = price else: if pos is None: @@ -178,7 +171,7 @@ def fetch_stock_order_status(self, id: int) -> Dict[str, Any]: self.orders.remove(ret) ret = ret_1 ret["status"] = "filled" - ret["filled_time"] = self.trader.timestamp + ret["filled_time"] = self.streamer.get_current_time() ret["filled_price"] = price self.equity = self._calc_equity() @@ -194,10 +187,7 @@ def fetch_option_order_status(self, id: int) -> Dict[str, Any]: sym = ret["base_symbol"] occ_sym = ret["symbol"] - if self.trader is None: - price = self.streamer.fetch_option_market_data(occ_sym)["price"] - else: - price = self.trader.streamer.fetch_option_market_data(occ_sym)["price"] + price = self.streamer.fetch_option_market_data(occ_sym)["price"] qty = ret["quantity"] original_price = price * qty @@ -243,7 +233,7 @@ def fetch_option_order_status(self, id: int) -> Dict[str, Any]: self.cash -= actual_price self.buying_power -= actual_price ret["status"] = "filled" - ret["filled_time"] = self.trader.timestamp + ret["filled_time"] = self.streamer.get_current_time() ret["filled_price"] = price debugger.debug(f"After BUY: {self.buying_power}") ret_1 = ret.copy() @@ -265,7 +255,7 @@ def fetch_option_order_status(self, id: int) -> Dict[str, Any]: if pos["quantity"] < 1e-8: self.options.remove(pos) ret["status"] = "filled" - ret["filled_time"] = self.trader.timestamp + ret["filled_time"] = self.streamer.get_current_time() ret["filled_price"] = price ret_1 = ret.copy() self.orders.remove(ret) @@ -324,7 +314,7 @@ def order_crypto_limit( ): data = { "type": "CRYPTO", - "symbol": '@'+symbol, + "symbol": "@" + symbol, "quantity": quantity, "filled_qty": quantity, "limit_price": limit_price, diff --git a/harvest/api/polygon.py b/harvest/api/polygon.py index 717fed54..6de3b270 100644 --- a/harvest/api/polygon.py +++ b/harvest/api/polygon.py @@ -21,7 +21,7 @@ def __init__(self, path: str = None, is_basic_account: bool = False): super().__init__(path) self.basic = is_basic_account - def setup(self, interval, trader=None, trader_main=None): + def setup(self, interval, trader_main=None): self.watch_stock = [] self.watch_crypto = [] @@ -32,7 +32,7 @@ def setup(self, interval, trader=None, trader_main=None): self.watch_stock.append(sym) self.option_cache = {} - super().setup(interval, trader, trader_main) + super().setup(interval, trader_main) def exit(self): self.option_cache = {} @@ -47,7 +47,9 @@ def main(self): return for s in combo: - df = self.fetch_price_history(s, Interval.MIN_1, now() - dt.timedelta(days=1), now()).iloc[-1] + df = self.fetch_price_history( + s, Interval.MIN_1, now() - dt.timedelta(days=1), now() + ).iloc[-1] df_dict[s] = df debugger.debug(df) self.trader_main(df_dict) diff --git a/harvest/api/robinhood.py b/harvest/api/robinhood.py index f98ae041..f67c755d 100644 --- a/harvest/api/robinhood.py +++ b/harvest/api/robinhood.py @@ -42,9 +42,9 @@ def refresh_cred(self): debugger.debug("Logged into Robinhood...") # @API._run_once - def setup(self, interval, trader=None, trader_main=None): + def setup(self, interval, trader_main=None): - super().setup(interval, trader, trader_main) + super().setup(interval, trader_main) # Robinhood only supports 15SEC, 1MIN interval for crypto for sym in interval: @@ -388,7 +388,7 @@ def fetch_option_order_status(self, id): "time_in_force": ret["time_in_force"], "status": ret["state"], "filled_time": filled_time, - "filled_price": filled_price + "filled_price": filled_price, } @API._exception_handler @@ -551,7 +551,7 @@ def order_crypto_limit( return { "type": "CRYPTO", "id": ret["id"], - "symbol": '@'+ symbol, + "symbol": "@" + symbol, } except: debugger.error("Error while placing order.\nReturned: {ret}", exc_info=True) @@ -630,9 +630,9 @@ def _format_df( df.columns = pd.MultiIndex.from_product([watch, df.columns]) return df.dropna() - + def _rh_datestr_to_datetime(self, date_str: str): - date_str = date_str[:-3]+date_str[-2:] + date_str = date_str[:-3] + date_str[-2:] return dt.datetime.strptime(date_str, "%Y-%m-%dT%H:%M:%S.%f%z") def create_secret(self, path): diff --git a/harvest/api/webull.py b/harvest/api/webull.py index 323f7700..1110e980 100644 --- a/harvest/api/webull.py +++ b/harvest/api/webull.py @@ -68,8 +68,8 @@ def enter_live_trade_pin(self): return return self.api.get_trade_token(self.config["wb_trade_pin"]) - def setup(self, interval: Dict, trader=None, trader_main=None): - super().setup(interval, trader, trader_main) + def setup(self, interval: Dict, trader_main=None): + super().setup(interval, trader_main) self.watch_stock = [] self.watch_crypto = [] self.watch_crypto_fmt = [] diff --git a/harvest/api/yahoo.py b/harvest/api/yahoo.py index 3f5eb42b..0d974902 100644 --- a/harvest/api/yahoo.py +++ b/harvest/api/yahoo.py @@ -3,6 +3,7 @@ from typing import Any, Dict, List, Tuple # External libraries +import tzlocal import pandas as pd import yfinance as yf @@ -25,8 +26,8 @@ class YahooStreamer(API): def __init__(self, path=None): self.timestamp = now() - def setup(self, interval: Dict, trader=None, trader_main=None): - super().setup(interval, trader, trader_main) + def setup(self, interval: Dict, trader_main=None): + super().setup(interval, trader_main) self.watch_ticker = {} @@ -169,7 +170,8 @@ def fetch_chain_info(self, symbol: str): return { "id": "n/a", "exp_dates": [ - convert_input_to_datetime(s, self.trader.timezone) for s in option_list + convert_input_to_datetime(s, tzlocal.get_localzone()) + for s in option_list ], "multiplier": 100, } diff --git a/harvest/cli.py b/harvest/cli.py index ce65a3e0..bbe4b9d6 100644 --- a/harvest/cli.py +++ b/harvest/cli.py @@ -101,9 +101,8 @@ def raise_helper(): help="buys and sells assets on your behalf", choices=list(brokers.keys()), ) -# start_parser.add_argument( -# "algos", nargs="+", help="paths to algorithms you want to run" -# ) + +# Directory with algos that you want to run, default is the current working directory. start_parser.add_argument( "-d", "--directory", @@ -117,9 +116,8 @@ def raise_helper(): # Parser for visualing data visualize_parser = subparsers.add_parser("visualize") -visualize_parser.add_argument( - "path", help="path to harvest generated data file" -) +visualize_parser.add_argument("path", help="path to harvest generated data file") + def main(): """ @@ -151,29 +149,32 @@ def start(args: argparse.Namespace, test: bool = False): debug = args.debug trader = LiveTrader(streamer=streamer, broker=broker, storage=storage, debug=debug) - # algos is a list of paths to files that have user defined algos + # Get the directories. directory = args.directory print(f"Searching directory {directory}") files = [fi for fi in listdir(directory) if isfile(join(directory, fi))] print(f"Found files {files}") + # For each file in the directory... for f in files: names = f.split(".") + # Filter out non-python files. if len(names) <= 1 or names[-1] != "py": continue name = "".join(names[:-1]) + # ...open it... with open(join(directory, f), "r") as algo_file: firstline = algo_file.readline() if firstline.find("HARVEST_SKIP") != -1: print(f"Skipping {f}") continue - # load in the entire file + # ...load in the entire file and add the algo to the trader. algo_path = os.path.realpath(join(directory, f)) spec = importlib.util.spec_from_file_location(name, algo_path) algo = importlib.util.module_from_spec(spec) spec.loader.exec_module(algo) - # iterate though the variables and if a variable is a subclass of BaseAlgo instantiate it and added to the trader + # Iterate though the variables and if a variable is a subclass of BaseAlgo instantiate it and added to the trader. for algo_cls in inspect.getmembers(algo): k, v = algo_cls[0], algo_cls[1] if inspect.isclass(v) and v != BaseAlgo and issubclass(v, BaseAlgo): @@ -183,11 +184,17 @@ def start(args: argparse.Namespace, test: bool = False): if not test: trader.start() + def visualize(args: argparse.Namespace): + """ + Read a csv or pickle file created by Harvest with ohlc data and graph the data. + :args: A Namespace object containing parsed user arguments. + """ import re import pandas as pd import mplfinance as mpf + # Open the file using the appropriate parser. if args.path.endswith(".csv"): df = pd.read_csv(args.path) df["timestamp"] = pd.to_datetime(df["timestamp"]) @@ -200,23 +207,36 @@ def visualize(args: argparse.Namespace): if df.empty: print(f"No data found in {args.path}.") + sys.exit(2) path = os.path.basename(args.path) + # File names are asset {ticker name}@{interval}.{file format} file_search = re.search("^(@?[\w]+)@([\w]+).(csv|pickle)$", path) symbol, interval = file_search.group(1), file_search.group(2) open_price = df.iloc[0]["open"] close_price = df.iloc[-1]["close"] high_price = df["high"].max() low_price = df["low"].min() + price_delta = close_price - open_price + price_delta_precent = 100 % (price_delta / open_price) + volume = df["volume"].sum() print(f"{symbol} at {interval}") - print("open", open_price) - print("high", high_price) - print("low", low_price) - print("close", close_price) - print("price change", close_price - open_price) - mpf.plot(df, type="candle", volume=True, show_nontrading=True) - + print(f"open\t{open_price}") + print(f"high\t{high_price}") + print(f"low\t{low_price}") + print(f"close\t{close_price}") + print(f"price change\t{price_delta}") + print(f"price change percentage\t{price_delta_precent}%") + print(f"volume\t{volume}") + mpf.plot( + df, + type="candle", + style="charles", + volume=True, + show_nontrading=True, + title=path, + ) def _get_storage(storage: str): diff --git a/harvest/storage/__init__.py b/harvest/storage/__init__.py index fe657b5e..40f3f3c6 100644 --- a/harvest/storage/__init__.py +++ b/harvest/storage/__init__.py @@ -1,5 +1,3 @@ from harvest.storage.base_storage import BaseStorage from harvest.storage.csv_storage import CSVStorage from harvest.storage.pickle_storage import PickleStorage - -from harvest.storage.base_logger import BaseLogger diff --git a/harvest/storage/base_storage.py b/harvest/storage/base_storage.py index 78f8cf1b..c234f320 100644 --- a/harvest/storage/base_storage.py +++ b/harvest/storage/base_storage.py @@ -43,6 +43,7 @@ as long as they implement the API properly. """ + class BaseStorage: """ A basic storage that is thread safe and stores data in memory. @@ -58,22 +59,22 @@ class BaseStorage: ] def __init__( - self, - price_storage_size: int = 200, - price_storage_limit: bool = True, - transaction_storage_size: int = 200, - transaction_storage_limit: bool = True, - performance_storage_size: int = 200, - performance_storage_limit: bool = True, - ): + self, + price_storage_size: int = 200, + price_storage_limit: bool = True, + transaction_storage_size: int = 200, + transaction_storage_limit: bool = True, + performance_storage_size: int = 200, + performance_storage_limit: bool = True, + ): """ - queue_size: The maximum number of data points to store for asset price history. + queue_size: The maximum number of data points to store for asset price history. This helps prevent the database from becoming infinitely large as time progresses. - limit_size: Whether to limit the size of price history to queue_size. + limit_size: Whether to limit the size of price history to queue_size. This may be set to False if the storage is being used for backtesting, in which case you would want to store as much data as possible. """ - self.storage_lock = Lock() # Lock + self.storage_lock = Lock() # Lock self.price_storage_size = price_storage_size self.price_storage_limit = price_storage_limit @@ -84,10 +85,17 @@ def __init__( # BaseStorage uses a python dictionary to store the data, # where key is asset symbol and value is a pandas dataframe. - self.storage_price = {} + self.storage_price = {} self.storage_transaction = pd.DataFrame( - columns=["timestamp", "algorithm_name", "symbol", "side", "quantity", "price"] + columns=[ + "timestamp", + "algorithm_name", + "symbol", + "side", + "quantity", + "price", + ] ) self.storage_performance = {} @@ -95,14 +103,6 @@ def __init__( self.storage_performance[interval] = pd.DataFrame( columns=["equity"], index=[] ) - - def setup(self, trader): - """ - Sets up the storage - """ - self.trader = trader - - def store( self, symbol: str, interval: Interval, data: pd.DataFrame, remove_duplicate=True @@ -131,8 +131,11 @@ def store( intervals[interval], data, remove_duplicate=remove_duplicate ) if self.price_storage_limit: - intervals[interval] = intervals[interval][-self.price_storage_size:] + intervals[interval] = intervals[interval][ + -self.price_storage_size : + ] except: + self.storage_lock.release() raise Exception("Append Failure, case not found!") else: # Add the data as a new interval @@ -150,13 +153,12 @@ def store( cur_len = len(self.storage_price[symbol][interval]) if self.price_storage_limit and cur_len > self.price_storage_size: # If we have more than N data points, remove the oldest data - self.storage_price[symbol][interval] = self.storage_price[symbol][interval].iloc[ - -self.price_storage_size : - ] + self.storage_price[symbol][interval] = self.storage_price[symbol][ + interval + ].iloc[-self.price_storage_size :] self.storage_lock.release() - def load( self, symbol: str, @@ -190,7 +192,9 @@ def load( ] intervals.sort(key=lambda interval_timedelta: interval_timedelta[1]) for interval_timedelta in intervals: + self.storage_lock.release() data = self.load(symbol, interval_timedelta[0], start, end) + self.storage_lock.acquire() if data is not None: self.storage_lock.release() return data @@ -201,7 +205,7 @@ def load( if start is None and end is None: self.storage_lock.release() - return data + return data # If the start and end are not defined, then set them to the # beginning and end of the data. @@ -210,20 +214,21 @@ def load( if end is None: end = data.index[-1] + self.storage_lock.release() return data.loc[start:end] - + def store_transaction( - self, + self, timestamp: dt.datetime, algorithm_name: str, - symbol: str, + symbol: str, side: str, quantity: int, - price: float + price: float, ) -> None: self.storage_transaction.append( [timestamp, algorithm_name, symbol, side, quantity, price], - ignore_index=True + ignore_index=True, ) def reset(self, symbol: str, interval: Interval): @@ -234,7 +239,6 @@ def reset(self, symbol: str, interval: Interval): self.storage_price[symbol][interval] = pd.DataFrame() self.storage_lock.release() - def _append( self, current_data: pd.DataFrame, @@ -267,22 +271,24 @@ def aggregate( self.storage_lock.acquire() data = self.storage_price[symbol][base] self.storage_price[symbol][target] = self._append( - self.storage_price[symbol][target], aggregate_df(data, target), remove_duplicate + self.storage_price[symbol][target], + aggregate_df(data, target), + remove_duplicate, ) cur_len = len(self.storage_price[symbol][target]) if self.price_storage_limit and cur_len > self.price_storage_size: - self.storage_price[symbol][target] = self.storage_price[symbol][target].iloc[ - -self.price_storage_size : - ] + self.storage_price[symbol][target] = self.storage_price[symbol][ + target + ].iloc[-self.price_storage_size :] self.storage_lock.release() - def init_performace_data(self, equity: float): + def init_performace_data(self, equity: float, timestamp): for interval, days in self.performance_history_intervals: self.storage_performance[interval] = pd.DataFrame( - {'equity': [equity]}, index=[self.trader.timestamp] + {"equity": [equity]}, index=[timestamp] ) - def add_performance_data(self, equity: float): + def add_performance_data(self, equity: float, timestamp): """ Adds the performance data to the storage. @@ -290,43 +296,39 @@ def add_performance_data(self, equity: float): It takes the current equity and adds it to each interval. - :param equity: Current equity of the account. + :param equity: Current equity of the account. """ - cur_timestamp = self.trader.timestamp # Performance history range up until '3 MONTHS' have the - # same interval as the polling interval of the trader. + # same interval as the polling interval of the trader. for interval, days in self.performance_history_intervals[0:3]: df = self.storage_performance[interval] - cutoff = cur_timestamp - dt.timedelta(days=days) + cutoff = timestamp - dt.timedelta(days=days) if df.index[0] < cutoff: df = df.loc[df.index >= cutoff] - df = df.append(pd.DataFrame({"equity": [equity]}, index=[cur_timestamp])) + df = df.append(pd.DataFrame({"equity": [equity]}, index=[timestamp])) self.storage_performance[interval] = df # Performance history intervals after '3 MONTHS' are populated # only for each day. for interval, days in self.performance_history_intervals[3:5]: df = self.storage_performance[interval] - if df.index[-1].date() == cur_timestamp.date(): - df = df.iloc[:-1] - df = df.append(pd.DataFrame({"equity": [equity]}, index=[cur_timestamp])) + if df.index[-1].date() == timestamp.date(): + df = df.iloc[:-1] + df = df.append(pd.DataFrame({"equity": [equity]}, index=[timestamp])) else: - df = df.append(pd.DataFrame({"equity": [equity]}, index=[cur_timestamp])) - cutoff = cur_timestamp - dt.timedelta(days=days) + df = df.append(pd.DataFrame({"equity": [equity]}, index=[timestamp])) + cutoff = timestamp - dt.timedelta(days=days) if df.index[0] < cutoff: df = df.loc[df.index >= cutoff] self.storage_performance[interval] = df - + df = self.storage_performance["ALL"] - if df.index[-1].date() == cur_timestamp.date(): - df = df.iloc[:-1] - df = df.append(pd.DataFrame({"equity": [equity]}, index=[cur_timestamp])) + if df.index[-1].date() == timestamp.date(): + df = df.iloc[:-1] + df = df.append(pd.DataFrame({"equity": [equity]}, index=[timestamp])) self.storage_performance["ALL"] = df debugger.debug("Performance data added") for k in self.storage_performance: debugger.debug(f"{k}:\n {self.storage_performance[k]}") - - - diff --git a/harvest/storage/csv_storage.py b/harvest/storage/csv_storage.py index 0a600acc..b03859e7 100644 --- a/harvest/storage/csv_storage.py +++ b/harvest/storage/csv_storage.py @@ -61,7 +61,7 @@ def store( if not data.empty: self.storage_lock.acquire() - self.storage[symbol][interval][symbol].to_csv( + self.storage_price[symbol][interval][symbol].to_csv( self.save_dir + f"/{symbol}@{interval_enum_to_string(interval)}.csv" ) self.storage_lock.release() diff --git a/harvest/storage/pickle_storage.py b/harvest/storage/pickle_storage.py index fb0a104d..5bf4462a 100644 --- a/harvest/storage/pickle_storage.py +++ b/harvest/storage/pickle_storage.py @@ -63,7 +63,7 @@ def store( if not data.empty and save_pickle: self.storage_lock.acquire() - self.storage[symbol][interval].to_pickle( + self.storage_price[symbol][interval].to_pickle( self.save_dir + f"/{symbol}@{interval_enum_to_string(interval)}.pickle" ) self.storage_lock.release() diff --git a/harvest/trader/tester.py b/harvest/trader/tester.py index 4c16daea..a9513914 100644 --- a/harvest/trader/tester.py +++ b/harvest/trader/tester.py @@ -15,7 +15,6 @@ import harvest.trader.trader as trader from harvest.api.yahoo import YahooStreamer from harvest.api.paper import PaperBroker -from harvest.storage import BaseLogger from harvest.utils import * @@ -67,8 +66,8 @@ def start( a.config() self._setup(source, interval, aggregations, path, start, end, period) - self.broker.setup(self.interval, self, self.main) - self.streamer.setup(self.interval, self, self.main) + self.broker.setup(self.interval, self.main) + self.streamer.setup(self.interval, self.main) for a in self.algo: a.setup() @@ -119,7 +118,7 @@ def _setup( common_end = None for s in self.interval: for i in [self.interval[s]["interval"]] + self.interval[s]["aggregations"]: - df = self.storage.load(s, i, no_slice=True) + df = self.storage.load(s, i) df = pandas_datetime_to_utc(df, self.timezone) if common_start is None or df.index[0] > common_start: common_start = df.index[0] @@ -147,7 +146,7 @@ def _setup( for s in self.interval: for i in [self.interval[s]["interval"]] + self.interval[s]["aggregations"]: - df = self.storage.load(s, i, no_slice=True).copy() + df = self.storage.load(s, i).copy() df = df.loc[start:end] self.storage.reset(s, i) self.storage.store(s, i, df) @@ -199,7 +198,7 @@ def _setup( debugger.debug("Formatting complete") for sym in self.interval: for agg in self.interval[sym]["aggregations"]: - data = self.storage.load(sym, int(agg) - 16, no_slice=True) + data = self.storage.load(sym, int(agg) - 16) data = pandas_datetime_to_utc(data, self.timezone) self.storage.store( sym, @@ -220,13 +219,13 @@ def _setup( self.df[sym] = {} inter = self.interval[sym]["interval"] interval_txt = interval_enum_to_string(inter) - df = self.storage.load(sym, inter, no_slice=True) + df = self.storage.load(sym, inter) self.df[sym][inter] = df.copy() for agg in self.interval[sym]["aggregations"]: # agg_txt = interval_enum_to_string(agg) # agg_txt = f"{interval_txt}+{agg_txt}" - df = self.storage.load(sym, int(agg) - 16, no_slice=True) + df = self.storage.load(sym, int(agg) - 16) self.df[sym][int(agg) - 16] = df.copy() # Trim data so start and end dates match between assets and intervals @@ -329,7 +328,7 @@ def run_backtest(self): df_dict[sym] = self.df[sym][inter].loc[self.timestamp] update = self._update_order_queue() - self._update_stats(df_dict, new=update, option_update=True) + self._update_position_cache(df_dict, new=update, option_update=True) for sym in self.interval: inter = self.interval[sym]["interval"] if is_freq(self.timestamp, inter): diff --git a/harvest/trader/trader.py b/harvest/trader/trader.py index 529341f8..9b1c1e05 100644 --- a/harvest/trader/trader.py +++ b/harvest/trader/trader.py @@ -19,7 +19,6 @@ from harvest.api.dummy import DummyStreamer from harvest.api.paper import PaperBroker from harvest.storage import BaseStorage -from harvest.storage import BaseLogger from harvest.server import Server @@ -47,10 +46,7 @@ def __init__(self, streamer=None, broker=None, storage=None, debug=False): self._set_streamer_broker(streamer, broker) # Initialize the storage - self.storage = ( - BaseStorage() if storage is None else storage - ) - self.storage.setup(self) + self.storage = BaseStorage() if storage is None else storage self._init_attributes() self._setup_debugger(debug) @@ -78,9 +74,6 @@ def _init_attributes(self): signal(SIGINT, self.exit) - # Initialize timestamp - self.timestamp = self.streamer.timestamp - self.watchlist_global = [] # List of securities specified in this class self.algo = [] # List of algorithms to run. self.account = {} # Local cache of account data. @@ -89,7 +82,6 @@ def _init_attributes(self): self.crypto_positions = [] # Local cache of current crypto positions. self.order_queue = [] # Queue of unfilled orders. - self.logger = BaseLogger() self.server = Server(self) # Initialize the web interface server self.timezone = tzlocal.get_localzone() @@ -148,13 +140,15 @@ def start( # Initialize the account self._setup_account() - self.storage.init_performace_data(self.account["equity"]) + self.storage.init_performace_data( + self.account["equity"], self.streamer.timestamp + ) - self.broker.setup(self.interval, self, self.main) + self.broker.setup(self.interval, self.main) if self.broker != self.streamer: # Only call the streamer setup if it is a different # instance than the broker otherwise some brokers can fail! - self.streamer.setup(self.interval, self, self.main) + self.streamer.setup(self.interval, self.main) # Initialize the storage self._storage_init(all_history) @@ -274,19 +268,23 @@ def _storage_init(self, all_history: bool): start = None if all_history else now() - dt.timedelta(days=3) df = self.streamer.fetch_price_history(sym, inter, start) self.storage.store(sym, inter, df) - + # ================== Functions for main routine ===================== def main(self, df_dict): """ Main loop of the Trader. """ - self.timestamp = self.streamer.timestamp # Periodically refresh access tokens - if self.timestamp.hour % 12 == 0 and self.timestamp.minute == 0: + if ( + self.streamer.timestamp.hour % 12 == 0 + and self.streamer.timestamp.minute == 0 + ): self.streamer.refresh_cred() - - self.storage.add_performance_data(self.account["equity"]) + + self.storage.add_performance_data( + self.account["equity"], self.streamer.timestamp + ) # Save the data locally for sym in df_dict: @@ -306,7 +304,7 @@ def main(self, df_dict): new_algo = [] for a in self.algo: - if not is_freq(self.timestamp, a.interval): + if not is_freq(self.streamer.timestamp, a.interval): new_algo.append(a) continue try: @@ -353,14 +351,16 @@ def _update_order_queue(self): # TODO: handle cancelled orders if order["status"] == "filled": order_filled = True - debugger.debug(f"Order {order['id']} filled at {order['filled_time']} at {order['filled_price']}") + debugger.debug( + f"Order {order['id']} filled at {order['filled_time']} at {order['filled_price']}" + ) self.storage.store_transaction( order["filled_time"], "N/A", order["symbol"], order["side"], order["quantity"], - order["filled_price"] + order["filled_price"], ) else: new_order.append(order) @@ -407,6 +407,7 @@ def _update_position_cache(self, df_dict, new=False, option_update=False): equity = net_value + self.account["cash"] self.account["equity"] = equity + self.stock_positions = self.broker.fetch_stock_positions() def _fetch_account_data(self): pos = self.broker.fetch_stock_positions() @@ -435,29 +436,33 @@ def buy(self, symbol: str, quantity: int, in_force: str, extended: bool): if symbol_type(symbol) == "OPTION": price = self.streamer.fetch_option_market_data(symbol)["price"] else: - price = self.storage.load( - symbol, self.interval[symbol]["interval"] - )[symbol]["close"][-1] + price = self.storage.load(symbol, self.interval[symbol]["interval"])[ + symbol + ]["close"][-1] limit_price = mark_up(price) total_price = limit_price * quantity + debugger.warning( + f"Attempting to buy {quantity} shares of {symbol} at price {price} with price limit {limit_price} and a maximum total price of {total_price}" + ) + if total_price >= buy_power: debugger.error( - "Not enough buying power.\n" + - f"Total price ({price} * {quantity} * 1.05 = {limit_price*quantity}) exceeds buying power {buy_power}." + - "Reduce purchase quantity or increase buying power." + "Not enough buying power.\n" + + f"Total price ({price} * {quantity} * 1.05 = {limit_price*quantity}) exceeds buying power {buy_power}." + + "Reduce purchase quantity or increase buying power." ) return None - + # TODO? Perform other checks ret = self.broker.buy(symbol, quantity, limit_price, in_force, extended) - + if ret is None: debugger.debug("BUY failed") return None self.order_queue.append(ret) - debugger.debug(f"BUY: {self.timestamp}, {symbol}, {quantity}") + debugger.debug(f"BUY: {self.streamer.timestamp}, {symbol}, {quantity}") return ret @@ -465,13 +470,15 @@ def sell(self, symbol: str, quantity: int, in_force: str, extended: bool): # Check how many of the given asset we currently own owned_qty = self.get_asset_quantity(symbol, exclude_pending_sell=True) if quantity > owned_qty: - debugger.debug("SELL failed: More quantities are being sold than currently owned.") + debugger.debug( + "SELL failed: More quantities are being sold than currently owned." + ) return None - + if symbol_type(symbol) == "OPTION": - price = self.trader.streamer.fetch_option_market_data(symbol)["price"] + price = self.streamer.fetch_option_market_data(symbol)["price"] else: - price = self.trader.storage.load(symbol, self.interval[symbol]["interval"])[ + price = self.storage.load(symbol, self.interval[symbol]["interval"])[ symbol ]["close"][-1] @@ -482,14 +489,13 @@ def sell(self, symbol: str, quantity: int, in_force: str, extended: bool): debugger.debug("SELL failed") return None self.order_queue.append(ret) - debugger.debug(f"SELL: {self.timestamp}, {symbol}, {quantity}") + debugger.debug(f"SELL: {self.streamer.timestamp}, {symbol}, {quantity}") return ret - + # ================ Helper Functions ====================== def get_asset_quantity( - self, symbol: str = None, - include_pending_buy = False, - exclude_pending_sell = False ) -> float: + self, symbol: str = None, include_pending_buy=False, exclude_pending_sell=False + ) -> float: """Returns the quantity owned of a specified asset. :param str? symbol: Symbol of asset. defaults to first symbol in watchlist @@ -498,40 +504,34 @@ def get_asset_quantity( """ if symbol is None: symbol = self.watchlist_global[0] - + if typ := symbol_type(symbol) == "OPTION": owned_qty = sum( - p["quantity"] - for p in self.option_positions - if p["symbol"] == symbol + p["quantity"] for p in self.option_positions if p["symbol"] == symbol ) elif typ == "CRYPTO": owned_qty = sum( - p["quantity"] - for p in self.crypto_positions - if p["symbol"] == symbol + p["quantity"] for p in self.crypto_positions if p["symbol"] == symbol ) else: owned_qty = sum( - p["quantity"] - for p in self.stock_positions - if p["symbol"] == symbol + p["quantity"] for p in self.stock_positions if p["symbol"] == symbol ) - + if include_pending_buy: owned_qty += sum( o["quantity"] for o in self.order_queue if o["symbol"] == symbol and o["side"] == "buy" ) - + if exclude_pending_sell: owned_qty -= sum( o["quantity"] for o in self.order_queue if o["symbol"] == symbol and o["side"] == "sell" ) - + return owned_qty # def buy_option(self, symbol: str, quantity: int, in_force: str): @@ -539,9 +539,9 @@ def get_asset_quantity( # if ret is None: # raise Exception("BUY failed") # self.order_queue.append(ret) - # debugger.debug(f"BUY: {self.timestamp}, {symbol}, {quantity}") + # debugger.debug(f"BUY: {self.streamer.timestamp}, {symbol}, {quantity}") # debugger.debug(f"BUY order queue: {self.order_queue}") - # self.logger.add_transaction(self.timestamp, "buy", "option", symbol, quantity) + # self.logger.add_transaction(self.streamer.timestamp, "buy", "option", symbol, quantity) # return ret # def sell_option(self, symbol: str, quantity: int, in_force: str): @@ -561,9 +561,9 @@ def get_asset_quantity( # if ret is None: # raise Exception("SELL failed") # self.order_queue.append(ret) - # debugger.debug(f"SELL: {self.timestamp}, {symbol}, {quantity}") + # debugger.debug(f"SELL: {self.streamer.timestamp}, {symbol}, {quantity}") # debugger.debug(f"SELL order queue: {self.order_queue}") - # self.logger.add_transaction(self.timestamp, "sell", "option", symbol, quantity) + # self.logger.add_transaction(self.streamer.timestamp, "sell", "option", symbol, quantity) # return ret def set_algo(self, algo): @@ -608,7 +608,7 @@ def __init__(self, streamer=None, storage=None, debug=False): # If streamer is not specified, use YahooStreamer self.streamer = YahooStreamer() if streamer is None else streamer - self.broker = PaperBroker() + self.broker = PaperBroker(streamer=self.streamer) self.storage = ( BaseStorage() if storage is None else storage diff --git a/harvest/utils.py b/harvest/utils.py index f8742d29..78677b74 100644 --- a/harvest/utils.py +++ b/harvest/utils.py @@ -12,6 +12,7 @@ # External Imports import pandas as pd +# Configure a logger used by all of Harvest. logging.basicConfig( level=logging.INFO, format="%(asctime)s : %(name)s : %(levelname)s : %(message)s", @@ -246,6 +247,7 @@ def datetime_utc_to_local(date_time: dt.datetime, timezone: ZoneInfo) -> dt.date Converts a datetime object in UTC to local time, represented as a timezone naive datetime object. """ + date_time = date_time.to_pydatetime() new_tz = date_time.astimezone(timezone) return new_tz.replace(tzinfo=None) diff --git a/tests/test_algo.py b/tests/test_algo.py index c657f4a0..913acab2 100644 --- a/tests/test_algo.py +++ b/tests/test_algo.py @@ -254,7 +254,7 @@ def test_buy_sell_option_auto(self, mock_mark_up): p = t.option_positions[0] self.assertEqual(p["symbol"], "X 110101C01000000") - t.algo[0].sell_option() + t.algo[0].sell_all_options() streamer.tick() # p = t.stock_positions[0] diff --git a/tests/test_api.py b/tests/test_api.py index f109596a..2f6a3d0e 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -25,7 +25,7 @@ def setUpClass(self): def test_timeout(self): stream = StreamAPI() stream.fetch_account = lambda: None - stream.fetch_price_history = lambda x, y: pd.DataFrame() + stream.fetch_price_history = lambda x, y, z: pd.DataFrame() stream.fetch_account = lambda: {"cash": 100, "equity": 100} t = PaperTrader(stream) stream.trader = t @@ -62,20 +62,20 @@ def test_timeout(self): data["A"]["A"]["close"][-1], t.storage.load("A", Interval.MIN_1)["A"]["close"][-1], ) - # Check if B has been duplicated + # Check if B has been set to 0 self.assertEqual( b_cur["B"]["close"][-1], t.storage.load("B", Interval.MIN_1)["B"]["close"][-2], ) self.assertEqual( - b_cur["B"]["close"][-1], + 0, t.storage.load("B", Interval.MIN_1)["B"]["close"][-1], ) def test_timeout_cancel(self): stream = StreamAPI() stream.fetch_account = lambda: None - stream.fetch_price_history = lambda x, y: pd.DataFrame() + stream.fetch_price_history = lambda x, y, z: pd.DataFrame() stream.fetch_account = lambda: {"cash": 100, "equity": 100} t = PaperTrader(stream) t.set_algo(BaseAlgo()) @@ -130,11 +130,7 @@ def test_timeout_cancel(self): def test_exceptions(self): api = API() - try: - api.create_secret("I dont exists") - self.assertTrue(False) - except Exception as e: - self.assertEqual(str(e), "I dont exists was not found.") + self.assertEqual(api.create_secret("I dont exists"), False) try: api.fetch_price_history('A', Interval.MIN_1, now(), now()) @@ -185,10 +181,16 @@ def test_exceptions(self): self.assertEqual(str(e), "API does not support this broker method: `fetch_crypto_order_status`.") try: - api.order_limit("buy", "A", 5, 7) + api.order_stock_limit("buy", "A", 5, 7) + self.assertTrue(False) + except NotImplementedError as e: + self.assertEqual(str(e), "API does not support this broker method: `order_stock_limit`.") + + try: + api.order_crypto_limit("buy", "@A", 5, 7) self.assertTrue(False) except NotImplementedError as e: - self.assertEqual(str(e), "API does not support this broker method: `order_limit`.") + self.assertEqual(str(e), "API does not support this broker method: `order_crypto_limit`.") try: api.order_option_limit("buy", "A", 5, 7, "call", now(), 8) @@ -196,6 +198,20 @@ def test_exceptions(self): except NotImplementedError as e: self.assertEqual(str(e), "API does not support this broker method: `order_option_limit`.") + try: + api.buy('A', -1, 0) + self.assertTrue(False) + except NotImplementedError as e: + self.assertEqual(str(e), "API does not support this broker method: `order_stock_limit`.") + + try: + api.sell('A', -1, 0) + self.assertTrue(False) + except NotImplementedError as e: + self.assertEqual(str(e), "API does not support this broker method: `order_stock_limit`.") + + + def test_base_cases(self): api = API() @@ -206,15 +222,13 @@ def test_base_cases(self): self.assertEqual(api.fetch_crypto_positions(), []) api.update_option_positions([]) self.assertEqual(api.fetch_order_queue(), []) - self.assertTrue(api.buy('A', -1) is None) - self.assertTrue(api.buy_option("A", -1) is None) def test_run_once(self): api = API() fn = lambda x: x + 1 wrapper = API._run_once(fn) - self.assertEqual(wrapper(api)(5), 6) - self.assertTrue(wrapper(api) is None) + self.assertEqual(wrapper(5), 6) + self.assertTrue(wrapper(5) is None) def test_timestamp(self): api = API() diff --git a/tests/test_api_paper.py b/tests/test_api_paper.py index ec640eba..6b2c9671 100644 --- a/tests/test_api_paper.py +++ b/tests/test_api_paper.py @@ -12,7 +12,6 @@ class TestPaperBroker(unittest.TestCase): def test_account(self): dummy = PaperBroker() - dummy.streamer = DummyStreamer() d = dummy.fetch_account() self.assertEqual(d["equity"], 1000000.0) self.assertEqual(d["cash"], 1000000.0) @@ -22,7 +21,6 @@ def test_account(self): def test_dummy_account(self): directory = pathlib.Path(__file__).parent.resolve() dummy = PaperBroker(str(directory) + "/../dummy_account.yaml") - dummy.streamer = DummyStreamer() stocks = dummy.fetch_stock_positions() self.assertEqual(len(stocks), 2) self.assertEqual(stocks[0]["symbol"], "A") @@ -37,7 +35,6 @@ def test_dummy_account(self): def test_buy_order_limit(self): dummy = PaperBroker() - dummy.streamer = DummyStreamer() interval = {"A": {"interval": Interval.MIN_1, "aggregations": []}} dummy.setup(interval) order = dummy.order_stock_limit("buy", "A", 5, 50000) @@ -56,10 +53,9 @@ def test_buy_order_limit(self): def test_buy(self): dummy = PaperBroker() - dummy.streamer = DummyStreamer() interval = {"A": {"interval": Interval.MIN_1, "aggregations": []}} dummy.setup(interval) - order = dummy.buy("A", 5) + order = dummy.buy("A", 5, 1e5) self.assertEqual(order["type"], "STOCK") self.assertEqual(order["id"], 0) self.assertEqual(order["symbol"], "A") @@ -76,7 +72,6 @@ def test_buy(self): def test_sell_order_limit(self): directory = pathlib.Path(__file__).parent.resolve() dummy = PaperBroker(str(directory) + "/../dummy_account.yaml") - dummy.streamer = DummyStreamer() interval = {"A": {"interval": Interval.MIN_1, "aggregations": []}} dummy.setup(interval) order = dummy.order_stock_limit("sell", "A", 2, 50000) @@ -96,7 +91,6 @@ def test_sell_order_limit(self): def test_sell(self): directory = pathlib.Path(__file__).parent.resolve() dummy = PaperBroker(str(directory) + "/../dummy_account.yaml") - dummy.streamer = DummyStreamer() interval = {"A": {"interval": Interval.MIN_1, "aggregations": []}} dummy.setup(interval) order = dummy.sell("A", 2) @@ -115,19 +109,18 @@ def test_sell(self): def test_order_option_limit(self): dummy = PaperBroker() - dummy.streamer = DummyStreamer() interval = {"A": {"interval": Interval.MIN_1, "aggregations": []}} dummy.setup(interval) - exp_date = dt.datetime.now() + dt.timedelta(hours=5) + exp_date = dt.datetime(2021, 11, 14) + dt.timedelta(hours=5) order = dummy.order_option_limit( "buy", "A", 5, 50000, "OPTION", exp_date, 50001 ) self.assertEqual(order["type"], "OPTION") self.assertEqual(order["id"], 0) - self.assertEqual(order["symbol"], "A 211106P50001000") + self.assertEqual(order["symbol"], "A 211114P50001000") status = dummy.fetch_option_order_status(order["id"]) - self.assertEqual(status["symbol"], "A 211106P50001000") + self.assertEqual(status["symbol"], "A 211114P50001000") self.assertEqual(status["quantity"], 5) def test_commission(self): diff --git a/tests/test_api_webull.py b/tests/test_api_webull.py index 3656cebd..2fe0220d 100644 --- a/tests/test_api_webull.py +++ b/tests/test_api_webull.py @@ -58,7 +58,7 @@ def test_main(df): "@BTC": {"interval": Interval.MIN_1, "aggregations": []}, "SPY": {"interval": Interval.MIN_1, "aggregations": []}, } - wb.setup(interval, None, test_main) + wb.setup(interval, test_main) wb.main() @not_gh_action @@ -72,7 +72,7 @@ def test_main(df): wb = Webull() watch = ["SPY"] interval = {"SPY": {"interval": Interval.MIN_1, "aggregations": []}} - wb.setup(interval, None, test_main) + wb.setup(interval, test_main) wb.main() @not_gh_action diff --git a/tests/test_api_yahoo.py b/tests/test_api_yahoo.py index f9734e67..46994d1a 100644 --- a/tests/test_api_yahoo.py +++ b/tests/test_api_yahoo.py @@ -43,7 +43,7 @@ def test_main(df): yh = YahooStreamer() watch = ["SPY", "AAPL", "@BTC"] - yh.setup(interval, None, test_main) + yh.setup(interval, test_main) yh.main() def test_main_single(self): @@ -54,14 +54,14 @@ def test_main(df): self.assertEqual(df["SPY"].columns[0][0], "SPY") yh = YahooStreamer() - yh.setup(interval, None, test_main) + yh.setup(interval, test_main) yh.main() def test_chain_info(self): t = PaperTrader() yh = YahooStreamer() interval = {"LMND": {"interval": Interval.MIN_1, "aggregations": []}} - yh.setup(interval, t, None) + yh.setup(interval) info = yh.fetch_chain_info("LMND") self.assertGreater(len(info["exp_dates"]), 0) @@ -69,7 +69,7 @@ def test_chain_data(self): t = PaperTrader() yh = YahooStreamer() interval = {"LMND": {"interval": Interval.MIN_1, "aggregations": []}} - yh.setup(interval, t, None) + yh.setup(interval) dates = yh.fetch_chain_info("LMND")["exp_dates"] data = yh.fetch_chain_data("LMND", dates[0]) self.assertGreater(len(data), 0) diff --git a/tests/test_storage.py b/tests/test_storage.py index bd2a7e20..4347be7a 100644 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -11,7 +11,7 @@ class TestBaseStorage(unittest.TestCase): def test_create_storage(self): storage = BaseStorage() - self.assertEqual(storage.storage, {}) + self.assertEqual(storage.storage_price, {}) def test_simple_store(self): storage = BaseStorage() @@ -19,8 +19,8 @@ def test_simple_store(self): storage.store("A", Interval.MIN_1, data.copy(True)) self.assertTrue(not pd.isna(data.iloc[0]["A"]["low"])) - self.assertEqual(list(storage.storage.keys()), ["A"]) - self.assertEqual(list(storage.storage["A"].keys()), [Interval.MIN_1]) + self.assertEqual(list(storage.storage_price.keys()), ["A"]) + self.assertEqual(list(storage.storage_price["A"].keys()), [Interval.MIN_1]) def test_simple_load(self): storage = BaseStorage() diff --git a/tests/test_storage_csv.py b/tests/test_storage_csv.py index 319901f5..156b193f 100644 --- a/tests/test_storage_csv.py +++ b/tests/test_storage_csv.py @@ -18,7 +18,7 @@ def setUpClass(self): def test_create_storage(self): storage = CSVStorage(self.storage_dir) - self.assertEqual(storage.storage, {}) + self.assertEqual(storage.storage_price, {}) def test_simple_store(self): storage = CSVStorage(self.storage_dir) @@ -26,8 +26,8 @@ def test_simple_store(self): storage.store("A", Interval.MIN_1, data.copy(True)) self.assertTrue(not pd.isna(data.iloc[0]["A"]["low"])) - self.assertTrue("A" in list(storage.storage.keys())) - self.assertEqual(list(storage.storage["A"].keys()), [Interval.MIN_1]) + self.assertTrue("A" in list(storage.storage_price.keys())) + self.assertEqual(list(storage.storage_price["A"].keys()), [Interval.MIN_1]) def test_saved_load(self): storage1 = CSVStorage(self.storage_dir) diff --git a/tests/test_storage_pickle.py b/tests/test_storage_pickle.py index 808724ef..59684af6 100644 --- a/tests/test_storage_pickle.py +++ b/tests/test_storage_pickle.py @@ -18,7 +18,7 @@ def setUpClass(self): def test_create_storage(self): storage = PickleStorage(self.storage_dir) - self.assertEqual(storage.storage, {}) + self.assertEqual(storage.storage_price, {}) def test_simple_store(self): storage = PickleStorage(self.storage_dir) @@ -26,8 +26,8 @@ def test_simple_store(self): storage.store("A", Interval.MIN_1, data.copy(True)) self.assertTrue(not pd.isna(data.iloc[0]["A"]["low"])) - self.assertTrue("A" in list(storage.storage.keys())) - self.assertEqual(list(storage.storage["A"].keys()), [Interval.MIN_1]) + self.assertTrue("A" in list(storage.storage_price.keys())) + self.assertEqual(list(storage.storage_price["A"].keys()), [Interval.MIN_1]) def test_saved_load(self): storage1 = PickleStorage(self.storage_dir) diff --git a/tests/test_tester.py b/tests/test_tester.py index 805f64f2..b0bcc67a 100644 --- a/tests/test_tester.py +++ b/tests/test_tester.py @@ -43,7 +43,7 @@ def test_check_aggregation(self): minutes = list(t.storage.load("A", Interval.MIN_1)["A"]["close"])[-200:] days_agg = list( - t.storage.load("A", int(Interval.DAY_1) - 16, no_slice=True)["A"]["close"] + t.storage.load("A", int(Interval.DAY_1) - 16)["A"]["close"] )[-200:] self.assertListEqual(minutes, days_agg)