diff --git a/README.md b/README.md
index 967de5bd..bcf9911d 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,5 @@
![Header](docs/banner.png)
-Harvest is a Python framework providing a **simple** and **flexible** framework for algorithmic trading. Visit Harvest's [**website**](https://tfukaza.github.io/harvest-website/) for details, tutorials, and documentation.
+Harvest is a simple yet flexible Python framework for algorithmic trading. Paper trade and live trade stocks, cryptos, and options![^1][^2] Visit Harvest's [**website**](https://tfukaza.github.io/harvest-website/) for details, tutorials, and documentation.
@@ -77,3 +77,6 @@ Currently looking for...
- Many of the brokers were also not designed to be used for algo-trading. Excessive access to their API can result in your account getting locked.
- Tutorials and documentation solely exist to provide technical references of the code. They are not recommendations of any specific securities or strategies.
- Use Harvest at your own responsibility. Developers of Harvest take no responsibility for any financial losses you incur by using Harvest. By using Harvest, you certify you understand that Harvest is a software in early development and may contain bugs and unexpected behaviors.
+
+[^1]: What assets you can trade depends on the broker you are using.
+[^2]: Backtesting is also available, but it is not supported for options.
diff --git a/docs/dev.md b/docs/dev.md
index 3207ae8d..e92d77d3 100644
--- a/docs/dev.md
+++ b/docs/dev.md
@@ -1,7 +1,22 @@
-# Overview
+## The Harvest Workflow
-### Harvest architecture
-![Harvest architecture](harvest_architecture.png)
+Because Harvest is an extensive framework it can be hard to understand how the system works at times. This document server to provide a high-level overview of just how Harvests works after a user starts the trader.
-### Harvest Flow on Trader `start()`
-![Harvest start flow](harvest_start_flow.png)
\ No newline at end of file
+## Fetching Data
+
+After the user starts the trader Harvest will fetch data from the streamer and update its storage on interval.
+
+![Fetching Data Workflow](fetch-data.png)
+
+1. First Harvest will run the streamer on the specified interval. Once the data has been collected, the streamer will call a callback or hook function that will pass operation back to the trader. In this callback function the streamer will return the latest OHLC data for the assets specified by the trader.
+2. In this callback, the trader will update the storage with the latest data and the will run each algorithm.
+
+## Running Algorithms
+
+After data is fetched, the algorithms are run linearly.
+
+![Running Algorithm Workflow](run-algo.png)
+
+1. The algorithm the user created will user functions provided in the `BaseAlgo` class which communicate with the Trader.
+2. Typically the user's algorithms will first ask for data on the assets they specified which will be stored in the Storage.
+3. After that the user's algoirthms will decided when to buy or sell assets based on the data they got from the Storage. This will leverage the Broker.
\ No newline at end of file
diff --git a/docs/fetch-data.png b/docs/fetch-data.png
new file mode 100644
index 00000000..30a4eb65
Binary files /dev/null and b/docs/fetch-data.png differ
diff --git a/docs/harvest_architecture.png b/docs/harvest_architecture.png
deleted file mode 100644
index 239acb3d..00000000
Binary files a/docs/harvest_architecture.png and /dev/null differ
diff --git a/docs/harvest_start_flow.png b/docs/harvest_start_flow.png
deleted file mode 100644
index e14cd4a5..00000000
Binary files a/docs/harvest_start_flow.png and /dev/null differ
diff --git a/docs/run-algo.png b/docs/run-algo.png
new file mode 100644
index 00000000..c01f09f7
Binary files /dev/null and b/docs/run-algo.png differ
diff --git a/examples/em_alpaca.py b/examples/em_alpaca.py
index 8213b40b..9ee6f518 100644
--- a/examples/em_alpaca.py
+++ b/examples/em_alpaca.py
@@ -1,3 +1,4 @@
+# HARVEST_SKIP
# Builtin imports
import logging
import datetime as dt
diff --git a/examples/em_kraken.py b/examples/em_kraken.py
new file mode 100644
index 00000000..909dd88b
--- /dev/null
+++ b/examples/em_kraken.py
@@ -0,0 +1,100 @@
+# HARVEST_SKIP
+# Builtin imports
+import logging
+import datetime as dt
+
+# Harvest imports
+from harvest.algo import BaseAlgo
+from harvest.trader import LiveTrader
+from harvest.api.kraken import Kraken
+from harvest.storage.csv_storage import CSVStorage
+
+# Third-party imports
+import pandas as pd
+import matplotlib.pyplot as plt
+import mplfinance as mpf
+
+
+class EMAlgo(BaseAlgo):
+ def setup(self):
+ now = dt.datetime.now()
+ logging.info(f"EMAlgo.setup ran at: {now}")
+
+ def init_ticker(ticker):
+ fig = mpf.figure()
+ ax1 = fig.add_subplot(2, 1, 1)
+ ax2 = fig.add_subplot(3, 1, 3)
+
+ return {
+ ticker: {
+ "initial_price": None, "ohlc": pd.DataFrame(),
+ "fig": fig,
+ "ax1": ax1,
+ "ax2": ax2
+ }
+ }
+
+ self.tickers = {}
+ # self.tickers.update(init_ticker("@BTC"))
+ self.tickers.update(init_ticker("@DOGE"))
+
+ def main(self):
+ now = dt.datetime.now()
+ logging.info(f"EMAlgo.main ran at: {now}")
+
+ if now - now.replace(hour=0, minute=0, second=0, microsecond=0) <= dt.timedelta(
+ seconds=60
+ ):
+ logger.info(f"It's a new day! Clearning OHLC caches!")
+ for ticker_value in self.tickers.values():
+ ticker_value["ohlc"] = pd.DataFrame()
+
+ for ticker, ticker_value in self.tickers.items():
+ current_price = self.get_asset_price(ticker)
+ current_ohlc = self.get_asset_candle(ticker)
+ if ticker_value["initial_price"] is None:
+ ticker_value["initial_price"] = current_price
+
+ if current_ohlc.empty:
+ logging.warn(f"{ticker}'s get_asset_candle_list returned an empty list.")
+ return
+
+ ticker_value["ohlc"] = ticker_value["ohlc"].append(current_ohlc)
+
+ self.process_ticker(ticker, ticker_value, current_price)
+
+ def process_ticker(self, ticker, ticker_data, current_price):
+ initial_price = ticker_data["initial_price"]
+ ohlc = ticker_data["ohlc"]
+
+ # Calculate the price change
+ delta_price = current_price - initial_price
+
+ # Print stock info
+ logging.info(f"{ticker} current price: ${current_price}")
+ logging.info(f"{ticker} price change: ${delta_price}")
+
+ # Update the OHLC graph
+ ticker_data['ax1'].clear()
+ ticker_data['ax2'].clear()
+ mpf.plot(ohlc, ax=ticker_data['ax1'], volume=ticker_data['ax2'], type="candle")
+ plt.pause(3)
+
+
+if __name__ == "__main__":
+ # Store the OHLC data in a folder called `em_storage` with each file stored as a csv document
+ csv_storage = CSVStorage(save_dir="em_storage")
+ # Our streamer and broker will be Alpaca. My secret keys are stored in `alpaca_secret.yaml`
+ kraken = Kraken(
+ path="accounts/kraken-secret.yaml"
+ )
+ em_algo = EMAlgo()
+ trader = LiveTrader(streamer=kraken, broker=kraken, storage=csv_storage, debug=True)
+
+ # trader.set_symbol("@BTC")
+ trader.set_symbol("@DOGE")
+ trader.set_algo(em_algo)
+ mpf.show()
+
+ # Update every minute
+ trader.start("1MIN", all_history=False)
diff --git a/examples/em_polygon.py b/examples/em_polygon.py
index 4f6430f8..95ebcda6 100644
--- a/examples/em_polygon.py
+++ b/examples/em_polygon.py
@@ -1,3 +1,4 @@
+# HARVEST_SKIP
# Builtin imports
import logging
import datetime as dt
diff --git a/harvest/algo.py b/harvest/algo.py
index de12af96..18610722 100644
--- a/harvest/algo.py
+++ b/harvest/algo.py
@@ -141,7 +141,7 @@ def sell(
debugger.debug(f"Algo SELL: {symbol}, {quantity}")
return self.trader.sell(symbol, quantity, in_force, extended)
-
+
def sell_all_options(self, symbol: str = None, in_force: str = "gtc"):
"""Sells all options of a stock
@@ -163,8 +163,8 @@ def sell_all_options(self, symbol: str = None, in_force: str = "gtc"):
for s in symbols:
debugger.debug(f"Algo SELL OPTION: {s}")
quantity = self.get_asset_quantity(s)
- ret.append(self.trader.sell_option(s, quantity, in_force))
-
+ ret.append(self.trader.sell(s, quantity, in_force, True))
+
return ret
# def buy_option(self, symbol: str, quantity: int = None, in_force: str = "gtc"):
@@ -522,9 +522,8 @@ def get_asset_quantity(self, symbol: str = None) -> float:
"""
if symbol is None:
symbol = self.watchlist[0]
-
- return self.trader.get_asset_quantity(symbol, exclude_pending_sell=True)
+ return self.trader.get_asset_quantity(symbol, exclude_pending_sell=True)
def get_asset_cost(self, symbol: str = None) -> float:
"""Returns the average cost of a specified asset.
@@ -771,7 +770,9 @@ def get_datetime(self):
:returns: The current date and time as a datetime object
"""
- return datetime_utc_to_local(self.trader.timestamp, self.trader.timezone)
+ return datetime_utc_to_local(
+ self.trader.streamer.timestamp, self.trader.timezone
+ )
def get_option_position_quantity(self, symbol: str = None) -> bool:
"""Returns the number of types of options held for a stock.
diff --git a/harvest/api/_base.py b/harvest/api/_base.py
index 7099cf51..e0c86b8f 100644
--- a/harvest/api/_base.py
+++ b/harvest/api/_base.py
@@ -22,7 +22,7 @@ class API:
Attributes
:interval_list: A list of supported intervals.
- :exchange: The market the API trades on. Ignored if the API is not a broker.
+ :exchange: The market the API trades on. Ignored if the API is not a broker.
"""
interval_list = [
@@ -53,11 +53,6 @@ def __init__(self, path: str = None):
:path: path to the YAML file containing credentials to communicate with the API.
If not specified, defaults to './secret.yaml'
"""
- self.trader = (
- None # Allows broker to handle the case when runs without a trader
- )
-
- self.run_count = 0
if path is None:
path = "./secret.yaml"
@@ -69,14 +64,15 @@ def __init__(self, path: str = None):
with open(path, "r") as stream:
self.config = yaml.safe_load(stream)
- self.run_count = 0
self.timestamp = now()
def create_secret(self, path: str):
"""
This method is called when the yaml file with credentials
is not found."""
- raise Exception(f"{path} was not found.")
+ # raise Exception(f"{path} was not found.")
+ debugger.warning(f"Assuming API does not need account information.")
+ return False
def refresh_cred(self):
"""
@@ -85,14 +81,15 @@ def refresh_cred(self):
"""
debugger.info(f"Refreshing credentials for {type(self).__name__}.")
- def setup(self, interval: Dict, trader=None, trader_main=None) -> None:
+ def setup(self, interval: Dict, trader_main=None) -> None:
"""
This function is called right before the algorithm begins,
and initializes several runtime parameters like
the symbols to watch and what interval data is needed.
+
+ :trader_main: A callback function to the trader which will pass the data to the algorithms.
"""
- self.trader = trader
self.trader_main = trader_main
min_interval = None
@@ -182,12 +179,12 @@ def main(self):
df_dict = {}
for sym in self.interval:
inter = self.interval[sym]["interval"]
- if is_freq(self.timestamp, inter):
- n = self.timestamp
+ if is_freq(harvest_timestamp, inter):
+ n = harvest_timestamp
latest = self.fetch_price_history(
sym, inter, n - interval_to_timedelta(inter) * 2, n
)
- debugger.debug(f"Price fetch returned: \n{latest}")
+ debugger.debug(f"{sym} price fetch returned: {latest}")
if latest is None or latest.empty:
continue
df_dict[sym] = latest.iloc[-1]
@@ -229,29 +226,40 @@ def wrapper(*args, **kwargs):
return wrapper
def _run_once(func):
- """ """
+ """
+ Wrapper to only allows wrapped functions to be run once.
+
+ :func: Function to wrap.
+ :returns: The return of the inputted function if it has not been run before and None otherwise.
+ """
+
+ ran = False
def wrapper(*args, **kwargs):
- self = args[0]
- if self.run_count == 0:
- self.run_count += 1
- return func(args, kwargs)
+ nonlocal ran
+ if not ran:
+ ran = True
+ return func(*args, **kwargs)
return None
return wrapper
# -------------- Streamer methods -------------- #
+ def get_current_time(self):
+ return now()
+
def fetch_price_history(
self,
symbol: str,
interval: Interval,
start: dt.datetime = None,
end: dt.datetime = None,
- ):
+ ) -> pd.DataFrame:
"""
Fetches historical price data for the specified asset and period
- using the API.
+ using the API. The first row is the earliest entry and the last
+ row is the latest entry.
:param symbol: The stock/crypto to get data for.
:param interval: The interval of requested historical data.
@@ -306,10 +314,10 @@ def fetch_option_market_data(self, symbol: str):
raise NotImplementedError(
f"{type(self).__name__} does not support this streamer method: `fetch_option_market_data`."
)
-
+
def fetch_market_hours(self, date: datetime.date):
"""
- Returns the market hours for a given day.
+ Returns the market hours for a given day.
Hours are based on the exchange specified in the class's 'exchange' attribute.
:returns: A dictionary with the following keys and values:
@@ -528,7 +536,7 @@ def order_stock_limit(
Raises an exception if order fails.
"""
raise NotImplementedError(
- f"{type(self).__name__} does not support this broker method: `order_limit`."
+ f"{type(self).__name__} does not support this broker method: `order_stock_limit`."
)
def order_crypto_limit(
@@ -556,7 +564,7 @@ def order_crypto_limit(
Raises an exception if order fails.
"""
raise NotImplementedError(
- f"{type(self).__name__} does not support this broker method: `order_limit`."
+ f"{type(self).__name__} does not support this broker method: `order_crypto_limit`."
)
def order_option_limit(
@@ -596,7 +604,12 @@ def order_option_limit(
# These do not need to be re-implemented in a subclass
def buy(
- self, symbol: str, quantity: int, limit_price: float, in_force: str = "gtc", extended: bool = False
+ self,
+ symbol: str,
+ quantity: int,
+ limit_price: float,
+ in_force: str = "gtc",
+ extended: bool = False,
):
"""
Buys the specified asset.
@@ -653,7 +666,7 @@ def sell(
:returns: The result of order_limit(). Returns None if there is an issue with the parameters.
"""
-
+
debugger.debug(f"{type(self).__name__} ordered a sell of {quantity} {symbol}")
typ = symbol_type(symbol)
@@ -707,8 +720,8 @@ def sell(
# if total_price >= buy_power:
# debugger.warning(
- # "Not enough buying power.\n" +
- # f"Total price ({price} * {quantity} * 1.05 = {limit_price*quantity}) exceeds buying power {buy_power}.\n" +
+ # "Not enough buying power.\n" +
+ # f"Total price ({price} * {quantity} * 1.05 = {limit_price*quantity}) exceeds buying power {buy_power}.\n" +
# "Reduce purchase quantity or increase buying power."
# )
@@ -817,8 +830,8 @@ def __init__(self, path: str = None):
self.block_queue = {}
self.first = True
- def setup(self, interval: Dict, trader=None, trader_main=None) -> None:
- super().setup(interval, trader, trader_main)
+ def setup(self, interval: Dict, trader_main=None) -> None:
+ super().setup(interval, trader_main)
self.blocker = {}
def start(self):
@@ -878,15 +891,15 @@ def timeout(self):
self.flush()
def flush(self):
- # For missing data, repeat the existing one
+ # For missing data, return a OHLC with all zeroes.
self.block_lock.acquire()
for n in self.needed:
- data = (
- self.trader.storage.load(n, self.interval[n]["interval"])
- .iloc[[-1]]
- .copy()
+ data = pd.DataFrame(
+ {"open": 0, "high": 0, "low": 0, "close": 0, "volume": 0},
+ index=[self.timestamp],
)
- data.index = [self.timestamp]
+
+ data.columns = pd.MultiIndex.from_product([[n], data.columns])
self.block_queue[n] = data
self.block_lock.release()
self.trader_main(self.block_queue)
diff --git a/harvest/api/alpaca.py b/harvest/api/alpaca.py
index c30f9b29..f1a52f24 100644
--- a/harvest/api/alpaca.py
+++ b/harvest/api/alpaca.py
@@ -76,8 +76,8 @@ async def update_data(self, bar):
else:
self.data_lock.release()
- def setup(self, interval: Dict, trader=None, trader_main=None):
- super().setup(interval, trader, trader_main)
+ def setup(self, interval: Dict, trader_main=None):
+ super().setup(interval, trader_main)
self.watch_stock = []
self.watch_crypto = []
diff --git a/harvest/api/dummy.py b/harvest/api/dummy.py
index 87083230..e30a6f63 100644
--- a/harvest/api/dummy.py
+++ b/harvest/api/dummy.py
@@ -27,21 +27,22 @@ class DummyStreamer(API):
Interval.HR_1,
Interval.DAY_1,
]
- default_now = dt.datetime(year=2000, month=1, day=1, hour=0, minute=0)
+
+ default_timestamp = dt.datetime(year=2000, month=1, day=1, hour=0, minute=0)
def __init__(
self,
- path: str = None,
- now: dt.datetime = default_now,
+ timestamp: dt.datetime = default_timestamp,
realistic_times: bool = False,
):
- self.trader = None
+
+ super().__init__(None)
+
self.trader_main = None
self.realistic_times = realistic_times
# Set the current time
- self._set_now(now)
- self.timestamp = self.now
+ self._set_timestamp(timestamp)
# Used so `fetch_price_history` can work without running `setup`
self.interval = self.interval_list[0]
# Store random values and generates for each asset tot make `fetch_price_history` fixed
@@ -65,7 +66,7 @@ def fetch_latest_stock_price(self) -> Dict[str, pd.DataFrame]:
"""
results = {}
- today = self.now
+ today = self.timestamp
last = today - dt.timedelta(days=3)
for symbol in self.interval:
@@ -83,7 +84,7 @@ def fetch_latest_crypto_price(self) -> Dict[str, pd.DataFrame]:
"""
results = {}
- today = self.now
+ today = self.timestamp
last = today - dt.timedelta(days=3)
for symbol in self.interval:
if is_crypto(symbol):
@@ -94,6 +95,9 @@ def fetch_latest_crypto_price(self) -> Dict[str, pd.DataFrame]:
# -------------- Streamer methods -------------- #
+ def get_current_time(self):
+ return self.timestamp
+
def fetch_price_history(
self,
symbol: str,
@@ -109,14 +113,14 @@ def fetch_price_history(
Interval.MIN_15,
Interval.MIN_30,
]:
- start = self.now - dt.timedelta(days=2)
+ start = self.timestamp - dt.timedelta(days=2)
elif interval == Interval.HR_1:
- start = self.now - dt.timedelta(days=14)
+ start = self.timestamp - dt.timedelta(days=14)
else:
- start = self.now - dt.timedelta(days=365)
+ start = self.timestamp - dt.timedelta(days=365)
if end is None:
- end = self.now
+ end = self.timestamp
if start.tzinfo is None or start.tzinfo.utcoffset(start) is None:
start = pytz.utc.localize(start)
@@ -168,7 +172,7 @@ def fetch_price_history(
self.randomness[symbol + "_rng"] = rng
# The inital price is arbitarly calculated from the first change in price
- start_price = 1000 * (self.randomness[symbol][0] + 0.51)
+ start_price = 1000 * (self.randomness[symbol][0] + 0.501)
times = []
current_time = start
@@ -218,7 +222,6 @@ def fetch_price_history(
results.columns = pd.MultiIndex.from_product([[symbol], results.columns])
results = aggregate_df(results, interval)
-
return results
# TODO: Generate dummy option data
@@ -227,7 +230,7 @@ def fetch_option_market_data(self, symbol: str):
# This is a placeholder so Trader doesn't crash
message = hashlib.sha256()
message.update(symbol.encode("utf-8"))
- message.update(str(self.now).encode("utf-8"))
+ message.update(str(self.timestamp).encode("utf-8"))
hsh = message.digest()
price = int.from_bytes(hsh[:4], "big") / (2 ** 32)
price = (price + 1) * 1.5
@@ -267,16 +270,16 @@ def fetch_option_market_data(self, symbol: str):
# ------------- Helper methods ------------- #
- def _set_now(self, current_datetime: dt.datetime) -> None:
+ def _set_timestamp(self, current_datetime: dt.datetime) -> None:
if (
current_datetime.tzinfo is None
or current_datetime.tzinfo.utcoffset(current_datetime) is None
):
- self.now = pytz.utc.localize(current_datetime)
+ self.timestamp = pytz.utc.localize(current_datetime)
else:
- self.now = current_datetime
+ self.timestamp = current_datetime
def tick(self) -> None:
- self.now += interval_to_timedelta(self.poll_interval)
+ self.timestamp += interval_to_timedelta(self.poll_interval)
if not self.trader_main == None:
self.main()
diff --git a/harvest/api/kraken.py b/harvest/api/kraken.py
index a6a8cb91..4b01cba0 100644
--- a/harvest/api/kraken.py
+++ b/harvest/api/kraken.py
@@ -14,16 +14,10 @@
class Kraken(API):
- interval_list = [
- Interval.MIN_1,
- Interval.MIN_5,
- Interval.MIN_15,
- Interval.MIN_30,
- Interval.HR_1,
- Interval.DAY_1,
- ]
+ interval_list = [Interval.MIN_1, Interval.MIN_5, Interval.HR_1, Interval.DAY_1]
+
crypto_ticker_to_kraken_names = {
- "BTC": "XXBT",
+ "BTC": "XXBTZ",
"ETH": "XETH",
"ADA": "ADA",
"USDT": "USDT",
@@ -113,19 +107,24 @@ class Kraken(API):
"OXT": "OXT",
}
+ kraken_names_to_crypto_ticker = {
+ v: k for k, v in crypto_ticker_to_kraken_names.items()
+ }
+
def __init__(self, path: str = None):
super().__init__(path)
self.api = krakenex.API(self.config["api_key"], self.config["secret_key"])
- def setup(self, watch: List[str], interval: str, trader=None, trader_main=None):
+ def setup(self, interval: Dict, trader_main=None):
+ super().setup(interval, trader_main)
self.watch_crypto = []
- if is_crypto(s):
- self.watch_crypto.append(s)
- else:
- debugger.error("Kraken does not support stocks.")
+ for sym in interval:
+ if is_crypto(sym):
+ self.watch_crypto.append(sym)
+ else:
+ debugger.warning(f"Kraken does not support stocks. Ignoring {sym}.")
self.option_cache = {}
- super().setup(watch, interval, interval, trader, trader_main)
def exit(self):
self.option_cache = {}
@@ -139,10 +138,13 @@ def main(self):
@API._exception_handler
def fetch_latest_crypto_price(self):
dfs = {}
- for symbol in self.watch_cryptos:
+ for symbol in self.watch_crypto:
dfs[symbol] = self.fetch_price_history(
- symbol, self.interval, now() - dt.timedelta(days=7), now()
- ).iloc[[0]]
+ symbol,
+ self.interval[symbol]["interval"],
+ now() - dt.timedelta(days=7),
+ now(),
+ ).iloc[[-1]]
return dfs
# -------------- Streamer methods -------------- #
@@ -151,7 +153,7 @@ def fetch_latest_crypto_price(self):
def fetch_price_history(
self,
symbol: str,
- interval: str,
+ interval: Interval,
start: dt.datetime = None,
end: dt.datetime = None,
):
@@ -175,17 +177,16 @@ def fetch_price_history(
f"Interval {interval} not in interval list. Possible options are: {self.interval_list}"
)
val, unit = expand_interval(interval)
- return self.get_data_from_kraken(symbol, val, unit, start, end)
+ df = self.get_data_from_kraken(symbol, val, unit, start, end)
+
+ return df
- @API._exception_handler
def fetch_chain_info(self, symbol: str):
raise NotImplementedError("Kraken does not support options.")
- @API._exception_handler
def fetch_chain_data(self, symbol: str, date: dt.datetime):
raise NotImplementedError("Kraken does not support options.")
- @API._exception_handler
def fetch_option_market_data(self, occ_symbol: str):
raise NotImplementedError("Kraken does not support options.")
@@ -203,51 +204,88 @@ def fetch_option_positions(self):
@API._exception_handler
def fetch_crypto_positions(self):
- return self.get_result(self.api.query_private("OpenOrders"))
+ positions = self.get_result(self.api.query_private("OpenPositions"))
+
+ def fmt(crypto: Dict[str, Any]):
+ # Remove the currency
+ symbol = crypto["pair"][:-4]
+ # Convert from kraken name to crypto currency ticker
+ symbol = kraken_name_to_crypto_ticker.get(symbol)
+ return {
+ "symbol": "@" + symbol,
+ "avg_price": float(crypto["cost"]) / float(crypto["vol"]),
+ "quantity": float(crypto["vol"]),
+ "kraken": crypto,
+ }
+
+ return [fmt(pos) for pos in positions]
- @API._exception_handler
def update_option_positions(self, positions: List[Any]):
- raise NotImplementedError("Kraken does not support options.")
+ debugger.error("Kraken does not support options. Doing nothing.")
@API._exception_handler
def fetch_account(self):
- return self.get_result(self.api.query_private("Balance"))
+ account = self.get_result(self.api.query_private("Balance"))
+ if account is None:
+ equity = 0
+ cash = 0
+ else:
+ equity = sum(float(v) for k, v in account.items() if k != "ZUSD")
+ cash = account.get("ZUSD", 0)
+ return {
+ "equity": equity,
+ "cash": cash,
+ "buying_power": equity + cash,
+ "multiplier": 1,
+ "kraken": account,
+ }
- @API._exception_handler
def fetch_stock_order_status(self, order_id: str):
return NotImplementedError("Kraken does not support stocks.")
- @API._exception_handler
- def fetch_option_order_status(self, id):
+ def fetch_option_order_status(self, order_id: str):
raise Exception("Kraken does not support options.")
@API._exception_handler
- def fetch_crypto_order_status(self, id: str):
- closed_orders = self.get_result(self.api.query_private("ClosedOrders"))
- orders = closed_orders["closed"] + self.fetch_order_queue()
- if id in orders.keys():
- return orders.get(id)
- raise Exception(f"{id} not found in your orders.")
+ def fetch_crypto_order_status(self, order_id: str):
+ order = self.api.query_private("QueryOrders", {"txid": order_id})
+ symbol = kraken_names_to_crypto_ticker.get(crypto["descr"]["pair"][:-4])
+ return {
+ "type": "CRYPTO",
+ "symbol": "@" + symbol,
+ "id": crypto.key(),
+ "quantity": float(crypto["vol"]),
+ "filled_quantity": float(crypto["vol_exec"]),
+ "side": crypto["descr"]["type"],
+ "time_in_force": None,
+ "status": crypto["status"],
+ "kraken": crypto,
+ }
# --------------- Methods for Trading --------------- #
@API._exception_handler
def fetch_order_queue(self):
open_orders = self.get_result(self.api.query_private("OpenOrders"))
- return open_orders["open"]
-
- def order_stock_limit(
- self,
- symbol: str,
- quantity: float,
- limit_price: float,
- order_type: str,
- time_in_force: str,
- order_id: str = None,
- ):
- raise NotImplementedError("Kraken does not support stocks.")
-
- def order_crypto_limit(
+ open_orders = open_orders["open"]
+
+ def fmt(crypto: Dict[str, Any]):
+ symbol = kraken_names_to_crypto_ticker.get(crypto["descr"]["pair"][:-4])
+ return {
+ "type": "CRYPTO",
+ "symbol": "@" + symbol,
+ "id": crypto.key(),
+ "quantity": float(crypto["vol"]),
+ "filled_quantity": float(crypto["vol_exec"]),
+ "side": crypto["descr"]["type"],
+ "time_in_force": None,
+ "status": crypto["status"],
+ "kraken": crypto,
+ }
+
+ return [fmt(order) for order in open_orders]
+
+ def order_limit(
self,
side: str,
symbol: str,
@@ -256,10 +294,12 @@ def order_crypto_limit(
in_force: str = "gtc",
extended: bool = False,
):
+ if is_crypto(symbol):
+ symbol = ticker_to_kraken(symbol)
+ else:
+ raise Exception("Kraken does not support stocks.")
- symbol = self.ticker_to_kraken(symbol)
-
- return self.get_result(
+ order = self.get_result(
self.api.query_private(
"AddOrder",
{
@@ -271,6 +311,13 @@ def order_crypto_limit(
)
)
+ return {
+ "type": "CRYPTO",
+ "id": order["txid"],
+ "symbol": symbol,
+ "kraken": order,
+ }
+
def order_option_limit(
self,
side: str,
@@ -336,6 +383,29 @@ def _format_df(self, df: pd.DataFrame, symbol: str):
return df.dropna()
+ def ticker_to_kraken(self, ticker: str):
+ if not is_crypto(ticker):
+ raise Exception("Kraken does not support stocks.")
+
+ if ticker[1:] in self.crypto_ticker_to_kraken_names:
+ # Currently Harvest supports trades for USD and not other currencies.
+ kraken_ticker = self.crypto_ticker_to_kraken_names.get(ticker[1:]) + "USD"
+ asset_pairs = self.get_result(self.api.query_public("AssetPairs")).keys()
+ if kraken_ticker in asset_pairs:
+ return kraken_ticker
+ else:
+ raise Exception(f"{kraken_ticker} is not a valid asset pair.")
+ else:
+ raise Exception(f"Kraken does not support ticker {ticker}.")
+
+ def get_result(self, response: Dict[str, Any]):
+ """Given a kraken response from an endpoint, either raise an error if an
+ error exists or return the data in the results key.
+ """
+ if len(response["error"]) > 0:
+ raise Exception("\n".join(response["error"]))
+ return response.get("result", None)
+
def create_secret(self, path: str) -> bool:
import harvest.wizard as wizard
@@ -382,26 +452,3 @@ def create_secret(self, path: str) -> bool:
)
return True
-
- def ticker_to_kraken(self, ticker: str):
- if not is_crypto(ticker):
- raise Exception("Kraken does not support stocks.")
-
- if ticker[1:] not in self.crypto_ticker_to_kraken_names:
- raise Exception(f"Kraken does not support ticker {ticker}.")
-
- # Currently Harvest supports trades for USD and not other currencies.
- kraken_ticker = self.crypto_ticker_to_kraken_names.get(ticker[1:]) + "USD"
- asset_pairs = self.get_result(self.api.query_public("AssetPairs")).keys()
- if kraken_ticker in asset_pairs:
- return kraken_ticker
- else:
- raise Exception(f"{kraken_ticker} is not a valid asset pair.")
-
- def get_result(self, response: Dict[str, Any]):
- """Given a kraken response from an endpoint, either raise an error if an
- error exists or return the data in the results key.
- """
- if len(response["error"]) > 0:
- raise Exception("\n".join(response["error"]))
- return response["result"]
diff --git a/harvest/api/paper.py b/harvest/api/paper.py
index 4661b437..99d00e14 100644
--- a/harvest/api/paper.py
+++ b/harvest/api/paper.py
@@ -9,6 +9,7 @@
# Submodule imports
from harvest.api._base import API
+from harvest.api.dummy import DummyStreamer
from harvest.utils import *
@@ -27,7 +28,7 @@ class PaperBroker(API):
Interval.DAY_1,
]
- def __init__(self, account_path: str = None, commission_fee=0):
+ def __init__(self, account_path: str = None, commission_fee=0, streamer=None):
"""
:commission_fee: When this is a number it is assumed to be a flat price
on all buys and sells of assets. When this is a string formatted as
@@ -49,6 +50,7 @@ def __init__(self, account_path: str = None, commission_fee=0):
self.multiplier = 1
self.commission_fee = commission_fee
self.id = 0
+ self.streamer = DummyStreamer() if streamer is None else streamer
if account_path:
with open(account_path, "r") as f:
@@ -64,8 +66,8 @@ def __init__(self, account_path: str = None, commission_fee=0):
for crypto in account["cryptos"]:
self.cryptos.append(crypto)
- def setup(self, interval, trader=None, trader_main=None):
- super().setup(interval, trader, trader_main)
+ def setup(self, interval, trader_main=None):
+ super().setup(interval, trader_main)
# -------------- Streamer methods -------------- #
@@ -89,11 +91,7 @@ def fetch_crypto_positions(self) -> List[Dict[str, Any]]:
def update_option_positions(self, positions) -> List[Dict[str, Any]]:
for r in self.options:
occ_sym = r["symbol"]
-
- if self.trader is None:
- price = self.fetch_option_market_data(occ_sym)["price"]
- else:
- price = self.trader.streamer.fetch_option_market_data(occ_sym)["price"]
+ price = self.streamer.fetch_option_market_data(occ_sym)["price"]
r["current_price"] = price
r["market_value"] = price * r["quantity"] * 100
@@ -112,17 +110,12 @@ def fetch_stock_order_status(self, id: int) -> Dict[str, Any]:
ret = next(r for r in self.orders if r["id"] == id)
sym = ret["symbol"]
- if self.trader is None:
- price = self.streamer.fetch_price_history(
- sym,
- self.interval[sym]["interval"],
- dt.datetime.now() - dt.timedelta(days=7),
- dt.datetime.now(),
- )[sym]["close"][-1]
- else:
- price = self.trader.storage.load(sym, self.interval[sym]["interval"])[sym][
- "close"
- ][-1]
+ price = self.streamer.fetch_price_history(
+ sym,
+ self.interval[sym]["interval"],
+ self.streamer.get_current_time() - dt.timedelta(days=7),
+ self.streamer.get_current_time(),
+ )[sym]["close"][-1]
qty = ret["quantity"]
original_price = price * qty
@@ -160,7 +153,7 @@ def fetch_stock_order_status(self, id: int) -> Dict[str, Any]:
self.orders.remove(ret)
ret = ret_1
ret["status"] = "filled"
- ret["filled_time"] = self.trader.timestamp
+ ret["filled_time"] = self.streamer.get_current_time()
ret["filled_price"] = price
else:
if pos is None:
@@ -178,7 +171,7 @@ def fetch_stock_order_status(self, id: int) -> Dict[str, Any]:
self.orders.remove(ret)
ret = ret_1
ret["status"] = "filled"
- ret["filled_time"] = self.trader.timestamp
+ ret["filled_time"] = self.streamer.get_current_time()
ret["filled_price"] = price
self.equity = self._calc_equity()
@@ -194,10 +187,7 @@ def fetch_option_order_status(self, id: int) -> Dict[str, Any]:
sym = ret["base_symbol"]
occ_sym = ret["symbol"]
- if self.trader is None:
- price = self.streamer.fetch_option_market_data(occ_sym)["price"]
- else:
- price = self.trader.streamer.fetch_option_market_data(occ_sym)["price"]
+ price = self.streamer.fetch_option_market_data(occ_sym)["price"]
qty = ret["quantity"]
original_price = price * qty
@@ -243,7 +233,7 @@ def fetch_option_order_status(self, id: int) -> Dict[str, Any]:
self.cash -= actual_price
self.buying_power -= actual_price
ret["status"] = "filled"
- ret["filled_time"] = self.trader.timestamp
+ ret["filled_time"] = self.streamer.get_current_time()
ret["filled_price"] = price
debugger.debug(f"After BUY: {self.buying_power}")
ret_1 = ret.copy()
@@ -265,7 +255,7 @@ def fetch_option_order_status(self, id: int) -> Dict[str, Any]:
if pos["quantity"] < 1e-8:
self.options.remove(pos)
ret["status"] = "filled"
- ret["filled_time"] = self.trader.timestamp
+ ret["filled_time"] = self.streamer.get_current_time()
ret["filled_price"] = price
ret_1 = ret.copy()
self.orders.remove(ret)
@@ -324,7 +314,7 @@ def order_crypto_limit(
):
data = {
"type": "CRYPTO",
- "symbol": '@'+symbol,
+ "symbol": "@" + symbol,
"quantity": quantity,
"filled_qty": quantity,
"limit_price": limit_price,
diff --git a/harvest/api/polygon.py b/harvest/api/polygon.py
index 717fed54..6de3b270 100644
--- a/harvest/api/polygon.py
+++ b/harvest/api/polygon.py
@@ -21,7 +21,7 @@ def __init__(self, path: str = None, is_basic_account: bool = False):
super().__init__(path)
self.basic = is_basic_account
- def setup(self, interval, trader=None, trader_main=None):
+ def setup(self, interval, trader_main=None):
self.watch_stock = []
self.watch_crypto = []
@@ -32,7 +32,7 @@ def setup(self, interval, trader=None, trader_main=None):
self.watch_stock.append(sym)
self.option_cache = {}
- super().setup(interval, trader, trader_main)
+ super().setup(interval, trader_main)
def exit(self):
self.option_cache = {}
@@ -47,7 +47,9 @@ def main(self):
return
for s in combo:
- df = self.fetch_price_history(s, Interval.MIN_1, now() - dt.timedelta(days=1), now()).iloc[-1]
+ df = self.fetch_price_history(
+ s, Interval.MIN_1, now() - dt.timedelta(days=1), now()
+ ).iloc[-1]
df_dict[s] = df
debugger.debug(df)
self.trader_main(df_dict)
diff --git a/harvest/api/robinhood.py b/harvest/api/robinhood.py
index f98ae041..f67c755d 100644
--- a/harvest/api/robinhood.py
+++ b/harvest/api/robinhood.py
@@ -42,9 +42,9 @@ def refresh_cred(self):
debugger.debug("Logged into Robinhood...")
# @API._run_once
- def setup(self, interval, trader=None, trader_main=None):
+ def setup(self, interval, trader_main=None):
- super().setup(interval, trader, trader_main)
+ super().setup(interval, trader_main)
# Robinhood only supports 15SEC, 1MIN interval for crypto
for sym in interval:
@@ -388,7 +388,7 @@ def fetch_option_order_status(self, id):
"time_in_force": ret["time_in_force"],
"status": ret["state"],
"filled_time": filled_time,
- "filled_price": filled_price
+ "filled_price": filled_price,
}
@API._exception_handler
@@ -551,7 +551,7 @@ def order_crypto_limit(
return {
"type": "CRYPTO",
"id": ret["id"],
- "symbol": '@'+ symbol,
+ "symbol": "@" + symbol,
}
except:
debugger.error("Error while placing order.\nReturned: {ret}", exc_info=True)
@@ -630,9 +630,9 @@ def _format_df(
df.columns = pd.MultiIndex.from_product([watch, df.columns])
return df.dropna()
-
+
def _rh_datestr_to_datetime(self, date_str: str):
- date_str = date_str[:-3]+date_str[-2:]
+ date_str = date_str[:-3] + date_str[-2:]
return dt.datetime.strptime(date_str, "%Y-%m-%dT%H:%M:%S.%f%z")
def create_secret(self, path):
diff --git a/harvest/api/webull.py b/harvest/api/webull.py
index 323f7700..1110e980 100644
--- a/harvest/api/webull.py
+++ b/harvest/api/webull.py
@@ -68,8 +68,8 @@ def enter_live_trade_pin(self):
return
return self.api.get_trade_token(self.config["wb_trade_pin"])
- def setup(self, interval: Dict, trader=None, trader_main=None):
- super().setup(interval, trader, trader_main)
+ def setup(self, interval: Dict, trader_main=None):
+ super().setup(interval, trader_main)
self.watch_stock = []
self.watch_crypto = []
self.watch_crypto_fmt = []
diff --git a/harvest/api/yahoo.py b/harvest/api/yahoo.py
index 3f5eb42b..0d974902 100644
--- a/harvest/api/yahoo.py
+++ b/harvest/api/yahoo.py
@@ -3,6 +3,7 @@
from typing import Any, Dict, List, Tuple
# External libraries
+import tzlocal
import pandas as pd
import yfinance as yf
@@ -25,8 +26,8 @@ class YahooStreamer(API):
def __init__(self, path=None):
self.timestamp = now()
- def setup(self, interval: Dict, trader=None, trader_main=None):
- super().setup(interval, trader, trader_main)
+ def setup(self, interval: Dict, trader_main=None):
+ super().setup(interval, trader_main)
self.watch_ticker = {}
@@ -169,7 +170,8 @@ def fetch_chain_info(self, symbol: str):
return {
"id": "n/a",
"exp_dates": [
- convert_input_to_datetime(s, self.trader.timezone) for s in option_list
+ convert_input_to_datetime(s, tzlocal.get_localzone())
+ for s in option_list
],
"multiplier": 100,
}
diff --git a/harvest/cli.py b/harvest/cli.py
index ce65a3e0..bbe4b9d6 100644
--- a/harvest/cli.py
+++ b/harvest/cli.py
@@ -101,9 +101,8 @@ def raise_helper():
help="buys and sells assets on your behalf",
choices=list(brokers.keys()),
)
-# start_parser.add_argument(
-# "algos", nargs="+", help="paths to algorithms you want to run"
-# )
+
+# Directory with algos that you want to run, default is the current working directory.
start_parser.add_argument(
"-d",
"--directory",
@@ -117,9 +116,8 @@ def raise_helper():
# Parser for visualing data
visualize_parser = subparsers.add_parser("visualize")
-visualize_parser.add_argument(
- "path", help="path to harvest generated data file"
-)
+visualize_parser.add_argument("path", help="path to harvest generated data file")
+
def main():
"""
@@ -151,29 +149,32 @@ def start(args: argparse.Namespace, test: bool = False):
debug = args.debug
trader = LiveTrader(streamer=streamer, broker=broker, storage=storage, debug=debug)
- # algos is a list of paths to files that have user defined algos
+ # Get the directories.
directory = args.directory
print(f"Searching directory {directory}")
files = [fi for fi in listdir(directory) if isfile(join(directory, fi))]
print(f"Found files {files}")
+ # For each file in the directory...
for f in files:
names = f.split(".")
+ # Filter out non-python files.
if len(names) <= 1 or names[-1] != "py":
continue
name = "".join(names[:-1])
+ # ...open it...
with open(join(directory, f), "r") as algo_file:
firstline = algo_file.readline()
if firstline.find("HARVEST_SKIP") != -1:
print(f"Skipping {f}")
continue
- # load in the entire file
+ # ...load in the entire file and add the algo to the trader.
algo_path = os.path.realpath(join(directory, f))
spec = importlib.util.spec_from_file_location(name, algo_path)
algo = importlib.util.module_from_spec(spec)
spec.loader.exec_module(algo)
- # iterate though the variables and if a variable is a subclass of BaseAlgo instantiate it and added to the trader
+ # Iterate though the variables and if a variable is a subclass of BaseAlgo instantiate it and added to the trader.
for algo_cls in inspect.getmembers(algo):
k, v = algo_cls[0], algo_cls[1]
if inspect.isclass(v) and v != BaseAlgo and issubclass(v, BaseAlgo):
@@ -183,11 +184,17 @@ def start(args: argparse.Namespace, test: bool = False):
if not test:
trader.start()
+
def visualize(args: argparse.Namespace):
+ """
+ Read a csv or pickle file created by Harvest with ohlc data and graph the data.
+ :args: A Namespace object containing parsed user arguments.
+ """
import re
import pandas as pd
import mplfinance as mpf
+ # Open the file using the appropriate parser.
if args.path.endswith(".csv"):
df = pd.read_csv(args.path)
df["timestamp"] = pd.to_datetime(df["timestamp"])
@@ -200,23 +207,36 @@ def visualize(args: argparse.Namespace):
if df.empty:
print(f"No data found in {args.path}.")
+ sys.exit(2)
path = os.path.basename(args.path)
+ # File names are asset {ticker name}@{interval}.{file format}
file_search = re.search("^(@?[\w]+)@([\w]+).(csv|pickle)$", path)
symbol, interval = file_search.group(1), file_search.group(2)
open_price = df.iloc[0]["open"]
close_price = df.iloc[-1]["close"]
high_price = df["high"].max()
low_price = df["low"].min()
+ price_delta = close_price - open_price
+ price_delta_precent = 100 % (price_delta / open_price)
+ volume = df["volume"].sum()
print(f"{symbol} at {interval}")
- print("open", open_price)
- print("high", high_price)
- print("low", low_price)
- print("close", close_price)
- print("price change", close_price - open_price)
- mpf.plot(df, type="candle", volume=True, show_nontrading=True)
-
+ print(f"open\t{open_price}")
+ print(f"high\t{high_price}")
+ print(f"low\t{low_price}")
+ print(f"close\t{close_price}")
+ print(f"price change\t{price_delta}")
+ print(f"price change percentage\t{price_delta_precent}%")
+ print(f"volume\t{volume}")
+ mpf.plot(
+ df,
+ type="candle",
+ style="charles",
+ volume=True,
+ show_nontrading=True,
+ title=path,
+ )
def _get_storage(storage: str):
diff --git a/harvest/storage/__init__.py b/harvest/storage/__init__.py
index fe657b5e..40f3f3c6 100644
--- a/harvest/storage/__init__.py
+++ b/harvest/storage/__init__.py
@@ -1,5 +1,3 @@
from harvest.storage.base_storage import BaseStorage
from harvest.storage.csv_storage import CSVStorage
from harvest.storage.pickle_storage import PickleStorage
-
-from harvest.storage.base_logger import BaseLogger
diff --git a/harvest/storage/base_storage.py b/harvest/storage/base_storage.py
index 78f8cf1b..c234f320 100644
--- a/harvest/storage/base_storage.py
+++ b/harvest/storage/base_storage.py
@@ -43,6 +43,7 @@
as long as they implement the API properly.
"""
+
class BaseStorage:
"""
A basic storage that is thread safe and stores data in memory.
@@ -58,22 +59,22 @@ class BaseStorage:
]
def __init__(
- self,
- price_storage_size: int = 200,
- price_storage_limit: bool = True,
- transaction_storage_size: int = 200,
- transaction_storage_limit: bool = True,
- performance_storage_size: int = 200,
- performance_storage_limit: bool = True,
- ):
+ self,
+ price_storage_size: int = 200,
+ price_storage_limit: bool = True,
+ transaction_storage_size: int = 200,
+ transaction_storage_limit: bool = True,
+ performance_storage_size: int = 200,
+ performance_storage_limit: bool = True,
+ ):
"""
- queue_size: The maximum number of data points to store for asset price history.
+ queue_size: The maximum number of data points to store for asset price history.
This helps prevent the database from becoming infinitely large as time progresses.
- limit_size: Whether to limit the size of price history to queue_size.
+ limit_size: Whether to limit the size of price history to queue_size.
This may be set to False if the storage is being used for backtesting, in which case
you would want to store as much data as possible.
"""
- self.storage_lock = Lock() # Lock
+ self.storage_lock = Lock() # Lock
self.price_storage_size = price_storage_size
self.price_storage_limit = price_storage_limit
@@ -84,10 +85,17 @@ def __init__(
# BaseStorage uses a python dictionary to store the data,
# where key is asset symbol and value is a pandas dataframe.
- self.storage_price = {}
+ self.storage_price = {}
self.storage_transaction = pd.DataFrame(
- columns=["timestamp", "algorithm_name", "symbol", "side", "quantity", "price"]
+ columns=[
+ "timestamp",
+ "algorithm_name",
+ "symbol",
+ "side",
+ "quantity",
+ "price",
+ ]
)
self.storage_performance = {}
@@ -95,14 +103,6 @@ def __init__(
self.storage_performance[interval] = pd.DataFrame(
columns=["equity"], index=[]
)
-
- def setup(self, trader):
- """
- Sets up the storage
- """
- self.trader = trader
-
-
def store(
self, symbol: str, interval: Interval, data: pd.DataFrame, remove_duplicate=True
@@ -131,8 +131,11 @@ def store(
intervals[interval], data, remove_duplicate=remove_duplicate
)
if self.price_storage_limit:
- intervals[interval] = intervals[interval][-self.price_storage_size:]
+ intervals[interval] = intervals[interval][
+ -self.price_storage_size :
+ ]
except:
+ self.storage_lock.release()
raise Exception("Append Failure, case not found!")
else:
# Add the data as a new interval
@@ -150,13 +153,12 @@ def store(
cur_len = len(self.storage_price[symbol][interval])
if self.price_storage_limit and cur_len > self.price_storage_size:
# If we have more than N data points, remove the oldest data
- self.storage_price[symbol][interval] = self.storage_price[symbol][interval].iloc[
- -self.price_storage_size :
- ]
+ self.storage_price[symbol][interval] = self.storage_price[symbol][
+ interval
+ ].iloc[-self.price_storage_size :]
self.storage_lock.release()
-
def load(
self,
symbol: str,
@@ -190,7 +192,9 @@ def load(
]
intervals.sort(key=lambda interval_timedelta: interval_timedelta[1])
for interval_timedelta in intervals:
+ self.storage_lock.release()
data = self.load(symbol, interval_timedelta[0], start, end)
+ self.storage_lock.acquire()
if data is not None:
self.storage_lock.release()
return data
@@ -201,7 +205,7 @@ def load(
if start is None and end is None:
self.storage_lock.release()
- return data
+ return data
# If the start and end are not defined, then set them to the
# beginning and end of the data.
@@ -210,20 +214,21 @@ def load(
if end is None:
end = data.index[-1]
+ self.storage_lock.release()
return data.loc[start:end]
-
+
def store_transaction(
- self,
+ self,
timestamp: dt.datetime,
algorithm_name: str,
- symbol: str,
+ symbol: str,
side: str,
quantity: int,
- price: float
+ price: float,
) -> None:
self.storage_transaction.append(
[timestamp, algorithm_name, symbol, side, quantity, price],
- ignore_index=True
+ ignore_index=True,
)
def reset(self, symbol: str, interval: Interval):
@@ -234,7 +239,6 @@ def reset(self, symbol: str, interval: Interval):
self.storage_price[symbol][interval] = pd.DataFrame()
self.storage_lock.release()
-
def _append(
self,
current_data: pd.DataFrame,
@@ -267,22 +271,24 @@ def aggregate(
self.storage_lock.acquire()
data = self.storage_price[symbol][base]
self.storage_price[symbol][target] = self._append(
- self.storage_price[symbol][target], aggregate_df(data, target), remove_duplicate
+ self.storage_price[symbol][target],
+ aggregate_df(data, target),
+ remove_duplicate,
)
cur_len = len(self.storage_price[symbol][target])
if self.price_storage_limit and cur_len > self.price_storage_size:
- self.storage_price[symbol][target] = self.storage_price[symbol][target].iloc[
- -self.price_storage_size :
- ]
+ self.storage_price[symbol][target] = self.storage_price[symbol][
+ target
+ ].iloc[-self.price_storage_size :]
self.storage_lock.release()
- def init_performace_data(self, equity: float):
+ def init_performace_data(self, equity: float, timestamp):
for interval, days in self.performance_history_intervals:
self.storage_performance[interval] = pd.DataFrame(
- {'equity': [equity]}, index=[self.trader.timestamp]
+ {"equity": [equity]}, index=[timestamp]
)
- def add_performance_data(self, equity: float):
+ def add_performance_data(self, equity: float, timestamp):
"""
Adds the performance data to the storage.
@@ -290,43 +296,39 @@ def add_performance_data(self, equity: float):
It takes the current equity and adds it to each interval.
- :param equity: Current equity of the account.
+ :param equity: Current equity of the account.
"""
- cur_timestamp = self.trader.timestamp
# Performance history range up until '3 MONTHS' have the
- # same interval as the polling interval of the trader.
+ # same interval as the polling interval of the trader.
for interval, days in self.performance_history_intervals[0:3]:
df = self.storage_performance[interval]
- cutoff = cur_timestamp - dt.timedelta(days=days)
+ cutoff = timestamp - dt.timedelta(days=days)
if df.index[0] < cutoff:
df = df.loc[df.index >= cutoff]
- df = df.append(pd.DataFrame({"equity": [equity]}, index=[cur_timestamp]))
+ df = df.append(pd.DataFrame({"equity": [equity]}, index=[timestamp]))
self.storage_performance[interval] = df
# Performance history intervals after '3 MONTHS' are populated
# only for each day.
for interval, days in self.performance_history_intervals[3:5]:
df = self.storage_performance[interval]
- if df.index[-1].date() == cur_timestamp.date():
- df = df.iloc[:-1]
- df = df.append(pd.DataFrame({"equity": [equity]}, index=[cur_timestamp]))
+ if df.index[-1].date() == timestamp.date():
+ df = df.iloc[:-1]
+ df = df.append(pd.DataFrame({"equity": [equity]}, index=[timestamp]))
else:
- df = df.append(pd.DataFrame({"equity": [equity]}, index=[cur_timestamp]))
- cutoff = cur_timestamp - dt.timedelta(days=days)
+ df = df.append(pd.DataFrame({"equity": [equity]}, index=[timestamp]))
+ cutoff = timestamp - dt.timedelta(days=days)
if df.index[0] < cutoff:
df = df.loc[df.index >= cutoff]
self.storage_performance[interval] = df
-
+
df = self.storage_performance["ALL"]
- if df.index[-1].date() == cur_timestamp.date():
- df = df.iloc[:-1]
- df = df.append(pd.DataFrame({"equity": [equity]}, index=[cur_timestamp]))
+ if df.index[-1].date() == timestamp.date():
+ df = df.iloc[:-1]
+ df = df.append(pd.DataFrame({"equity": [equity]}, index=[timestamp]))
self.storage_performance["ALL"] = df
debugger.debug("Performance data added")
for k in self.storage_performance:
debugger.debug(f"{k}:\n {self.storage_performance[k]}")
-
-
-
diff --git a/harvest/storage/csv_storage.py b/harvest/storage/csv_storage.py
index 0a600acc..b03859e7 100644
--- a/harvest/storage/csv_storage.py
+++ b/harvest/storage/csv_storage.py
@@ -61,7 +61,7 @@ def store(
if not data.empty:
self.storage_lock.acquire()
- self.storage[symbol][interval][symbol].to_csv(
+ self.storage_price[symbol][interval][symbol].to_csv(
self.save_dir + f"/{symbol}@{interval_enum_to_string(interval)}.csv"
)
self.storage_lock.release()
diff --git a/harvest/storage/pickle_storage.py b/harvest/storage/pickle_storage.py
index fb0a104d..5bf4462a 100644
--- a/harvest/storage/pickle_storage.py
+++ b/harvest/storage/pickle_storage.py
@@ -63,7 +63,7 @@ def store(
if not data.empty and save_pickle:
self.storage_lock.acquire()
- self.storage[symbol][interval].to_pickle(
+ self.storage_price[symbol][interval].to_pickle(
self.save_dir + f"/{symbol}@{interval_enum_to_string(interval)}.pickle"
)
self.storage_lock.release()
diff --git a/harvest/trader/tester.py b/harvest/trader/tester.py
index 4c16daea..a9513914 100644
--- a/harvest/trader/tester.py
+++ b/harvest/trader/tester.py
@@ -15,7 +15,6 @@
import harvest.trader.trader as trader
from harvest.api.yahoo import YahooStreamer
from harvest.api.paper import PaperBroker
-from harvest.storage import BaseLogger
from harvest.utils import *
@@ -67,8 +66,8 @@ def start(
a.config()
self._setup(source, interval, aggregations, path, start, end, period)
- self.broker.setup(self.interval, self, self.main)
- self.streamer.setup(self.interval, self, self.main)
+ self.broker.setup(self.interval, self.main)
+ self.streamer.setup(self.interval, self.main)
for a in self.algo:
a.setup()
@@ -119,7 +118,7 @@ def _setup(
common_end = None
for s in self.interval:
for i in [self.interval[s]["interval"]] + self.interval[s]["aggregations"]:
- df = self.storage.load(s, i, no_slice=True)
+ df = self.storage.load(s, i)
df = pandas_datetime_to_utc(df, self.timezone)
if common_start is None or df.index[0] > common_start:
common_start = df.index[0]
@@ -147,7 +146,7 @@ def _setup(
for s in self.interval:
for i in [self.interval[s]["interval"]] + self.interval[s]["aggregations"]:
- df = self.storage.load(s, i, no_slice=True).copy()
+ df = self.storage.load(s, i).copy()
df = df.loc[start:end]
self.storage.reset(s, i)
self.storage.store(s, i, df)
@@ -199,7 +198,7 @@ def _setup(
debugger.debug("Formatting complete")
for sym in self.interval:
for agg in self.interval[sym]["aggregations"]:
- data = self.storage.load(sym, int(agg) - 16, no_slice=True)
+ data = self.storage.load(sym, int(agg) - 16)
data = pandas_datetime_to_utc(data, self.timezone)
self.storage.store(
sym,
@@ -220,13 +219,13 @@ def _setup(
self.df[sym] = {}
inter = self.interval[sym]["interval"]
interval_txt = interval_enum_to_string(inter)
- df = self.storage.load(sym, inter, no_slice=True)
+ df = self.storage.load(sym, inter)
self.df[sym][inter] = df.copy()
for agg in self.interval[sym]["aggregations"]:
# agg_txt = interval_enum_to_string(agg)
# agg_txt = f"{interval_txt}+{agg_txt}"
- df = self.storage.load(sym, int(agg) - 16, no_slice=True)
+ df = self.storage.load(sym, int(agg) - 16)
self.df[sym][int(agg) - 16] = df.copy()
# Trim data so start and end dates match between assets and intervals
@@ -329,7 +328,7 @@ def run_backtest(self):
df_dict[sym] = self.df[sym][inter].loc[self.timestamp]
update = self._update_order_queue()
- self._update_stats(df_dict, new=update, option_update=True)
+ self._update_position_cache(df_dict, new=update, option_update=True)
for sym in self.interval:
inter = self.interval[sym]["interval"]
if is_freq(self.timestamp, inter):
diff --git a/harvest/trader/trader.py b/harvest/trader/trader.py
index 529341f8..9b1c1e05 100644
--- a/harvest/trader/trader.py
+++ b/harvest/trader/trader.py
@@ -19,7 +19,6 @@
from harvest.api.dummy import DummyStreamer
from harvest.api.paper import PaperBroker
from harvest.storage import BaseStorage
-from harvest.storage import BaseLogger
from harvest.server import Server
@@ -47,10 +46,7 @@ def __init__(self, streamer=None, broker=None, storage=None, debug=False):
self._set_streamer_broker(streamer, broker)
# Initialize the storage
- self.storage = (
- BaseStorage() if storage is None else storage
- )
- self.storage.setup(self)
+ self.storage = BaseStorage() if storage is None else storage
self._init_attributes()
self._setup_debugger(debug)
@@ -78,9 +74,6 @@ def _init_attributes(self):
signal(SIGINT, self.exit)
- # Initialize timestamp
- self.timestamp = self.streamer.timestamp
-
self.watchlist_global = [] # List of securities specified in this class
self.algo = [] # List of algorithms to run.
self.account = {} # Local cache of account data.
@@ -89,7 +82,6 @@ def _init_attributes(self):
self.crypto_positions = [] # Local cache of current crypto positions.
self.order_queue = [] # Queue of unfilled orders.
- self.logger = BaseLogger()
self.server = Server(self) # Initialize the web interface server
self.timezone = tzlocal.get_localzone()
@@ -148,13 +140,15 @@ def start(
# Initialize the account
self._setup_account()
- self.storage.init_performace_data(self.account["equity"])
+ self.storage.init_performace_data(
+ self.account["equity"], self.streamer.timestamp
+ )
- self.broker.setup(self.interval, self, self.main)
+ self.broker.setup(self.interval, self.main)
if self.broker != self.streamer:
# Only call the streamer setup if it is a different
# instance than the broker otherwise some brokers can fail!
- self.streamer.setup(self.interval, self, self.main)
+ self.streamer.setup(self.interval, self.main)
# Initialize the storage
self._storage_init(all_history)
@@ -274,19 +268,23 @@ def _storage_init(self, all_history: bool):
start = None if all_history else now() - dt.timedelta(days=3)
df = self.streamer.fetch_price_history(sym, inter, start)
self.storage.store(sym, inter, df)
-
+
# ================== Functions for main routine =====================
def main(self, df_dict):
"""
Main loop of the Trader.
"""
- self.timestamp = self.streamer.timestamp
# Periodically refresh access tokens
- if self.timestamp.hour % 12 == 0 and self.timestamp.minute == 0:
+ if (
+ self.streamer.timestamp.hour % 12 == 0
+ and self.streamer.timestamp.minute == 0
+ ):
self.streamer.refresh_cred()
-
- self.storage.add_performance_data(self.account["equity"])
+
+ self.storage.add_performance_data(
+ self.account["equity"], self.streamer.timestamp
+ )
# Save the data locally
for sym in df_dict:
@@ -306,7 +304,7 @@ def main(self, df_dict):
new_algo = []
for a in self.algo:
- if not is_freq(self.timestamp, a.interval):
+ if not is_freq(self.streamer.timestamp, a.interval):
new_algo.append(a)
continue
try:
@@ -353,14 +351,16 @@ def _update_order_queue(self):
# TODO: handle cancelled orders
if order["status"] == "filled":
order_filled = True
- debugger.debug(f"Order {order['id']} filled at {order['filled_time']} at {order['filled_price']}")
+ debugger.debug(
+ f"Order {order['id']} filled at {order['filled_time']} at {order['filled_price']}"
+ )
self.storage.store_transaction(
order["filled_time"],
"N/A",
order["symbol"],
order["side"],
order["quantity"],
- order["filled_price"]
+ order["filled_price"],
)
else:
new_order.append(order)
@@ -407,6 +407,7 @@ def _update_position_cache(self, df_dict, new=False, option_update=False):
equity = net_value + self.account["cash"]
self.account["equity"] = equity
+ self.stock_positions = self.broker.fetch_stock_positions()
def _fetch_account_data(self):
pos = self.broker.fetch_stock_positions()
@@ -435,29 +436,33 @@ def buy(self, symbol: str, quantity: int, in_force: str, extended: bool):
if symbol_type(symbol) == "OPTION":
price = self.streamer.fetch_option_market_data(symbol)["price"]
else:
- price = self.storage.load(
- symbol, self.interval[symbol]["interval"]
- )[symbol]["close"][-1]
+ price = self.storage.load(symbol, self.interval[symbol]["interval"])[
+ symbol
+ ]["close"][-1]
limit_price = mark_up(price)
total_price = limit_price * quantity
+ debugger.warning(
+ f"Attempting to buy {quantity} shares of {symbol} at price {price} with price limit {limit_price} and a maximum total price of {total_price}"
+ )
+
if total_price >= buy_power:
debugger.error(
- "Not enough buying power.\n" +
- f"Total price ({price} * {quantity} * 1.05 = {limit_price*quantity}) exceeds buying power {buy_power}." +
- "Reduce purchase quantity or increase buying power."
+ "Not enough buying power.\n"
+ + f"Total price ({price} * {quantity} * 1.05 = {limit_price*quantity}) exceeds buying power {buy_power}."
+ + "Reduce purchase quantity or increase buying power."
)
return None
-
+
# TODO? Perform other checks
ret = self.broker.buy(symbol, quantity, limit_price, in_force, extended)
-
+
if ret is None:
debugger.debug("BUY failed")
return None
self.order_queue.append(ret)
- debugger.debug(f"BUY: {self.timestamp}, {symbol}, {quantity}")
+ debugger.debug(f"BUY: {self.streamer.timestamp}, {symbol}, {quantity}")
return ret
@@ -465,13 +470,15 @@ def sell(self, symbol: str, quantity: int, in_force: str, extended: bool):
# Check how many of the given asset we currently own
owned_qty = self.get_asset_quantity(symbol, exclude_pending_sell=True)
if quantity > owned_qty:
- debugger.debug("SELL failed: More quantities are being sold than currently owned.")
+ debugger.debug(
+ "SELL failed: More quantities are being sold than currently owned."
+ )
return None
-
+
if symbol_type(symbol) == "OPTION":
- price = self.trader.streamer.fetch_option_market_data(symbol)["price"]
+ price = self.streamer.fetch_option_market_data(symbol)["price"]
else:
- price = self.trader.storage.load(symbol, self.interval[symbol]["interval"])[
+ price = self.storage.load(symbol, self.interval[symbol]["interval"])[
symbol
]["close"][-1]
@@ -482,14 +489,13 @@ def sell(self, symbol: str, quantity: int, in_force: str, extended: bool):
debugger.debug("SELL failed")
return None
self.order_queue.append(ret)
- debugger.debug(f"SELL: {self.timestamp}, {symbol}, {quantity}")
+ debugger.debug(f"SELL: {self.streamer.timestamp}, {symbol}, {quantity}")
return ret
-
+
# ================ Helper Functions ======================
def get_asset_quantity(
- self, symbol: str = None,
- include_pending_buy = False,
- exclude_pending_sell = False ) -> float:
+ self, symbol: str = None, include_pending_buy=False, exclude_pending_sell=False
+ ) -> float:
"""Returns the quantity owned of a specified asset.
:param str? symbol: Symbol of asset. defaults to first symbol in watchlist
@@ -498,40 +504,34 @@ def get_asset_quantity(
"""
if symbol is None:
symbol = self.watchlist_global[0]
-
+
if typ := symbol_type(symbol) == "OPTION":
owned_qty = sum(
- p["quantity"]
- for p in self.option_positions
- if p["symbol"] == symbol
+ p["quantity"] for p in self.option_positions if p["symbol"] == symbol
)
elif typ == "CRYPTO":
owned_qty = sum(
- p["quantity"]
- for p in self.crypto_positions
- if p["symbol"] == symbol
+ p["quantity"] for p in self.crypto_positions if p["symbol"] == symbol
)
else:
owned_qty = sum(
- p["quantity"]
- for p in self.stock_positions
- if p["symbol"] == symbol
+ p["quantity"] for p in self.stock_positions if p["symbol"] == symbol
)
-
+
if include_pending_buy:
owned_qty += sum(
o["quantity"]
for o in self.order_queue
if o["symbol"] == symbol and o["side"] == "buy"
)
-
+
if exclude_pending_sell:
owned_qty -= sum(
o["quantity"]
for o in self.order_queue
if o["symbol"] == symbol and o["side"] == "sell"
)
-
+
return owned_qty
# def buy_option(self, symbol: str, quantity: int, in_force: str):
@@ -539,9 +539,9 @@ def get_asset_quantity(
# if ret is None:
# raise Exception("BUY failed")
# self.order_queue.append(ret)
- # debugger.debug(f"BUY: {self.timestamp}, {symbol}, {quantity}")
+ # debugger.debug(f"BUY: {self.streamer.timestamp}, {symbol}, {quantity}")
# debugger.debug(f"BUY order queue: {self.order_queue}")
- # self.logger.add_transaction(self.timestamp, "buy", "option", symbol, quantity)
+ # self.logger.add_transaction(self.streamer.timestamp, "buy", "option", symbol, quantity)
# return ret
# def sell_option(self, symbol: str, quantity: int, in_force: str):
@@ -561,9 +561,9 @@ def get_asset_quantity(
# if ret is None:
# raise Exception("SELL failed")
# self.order_queue.append(ret)
- # debugger.debug(f"SELL: {self.timestamp}, {symbol}, {quantity}")
+ # debugger.debug(f"SELL: {self.streamer.timestamp}, {symbol}, {quantity}")
# debugger.debug(f"SELL order queue: {self.order_queue}")
- # self.logger.add_transaction(self.timestamp, "sell", "option", symbol, quantity)
+ # self.logger.add_transaction(self.streamer.timestamp, "sell", "option", symbol, quantity)
# return ret
def set_algo(self, algo):
@@ -608,7 +608,7 @@ def __init__(self, streamer=None, storage=None, debug=False):
# If streamer is not specified, use YahooStreamer
self.streamer = YahooStreamer() if streamer is None else streamer
- self.broker = PaperBroker()
+ self.broker = PaperBroker(streamer=self.streamer)
self.storage = (
BaseStorage() if storage is None else storage
diff --git a/harvest/utils.py b/harvest/utils.py
index f8742d29..78677b74 100644
--- a/harvest/utils.py
+++ b/harvest/utils.py
@@ -12,6 +12,7 @@
# External Imports
import pandas as pd
+# Configure a logger used by all of Harvest.
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s : %(name)s : %(levelname)s : %(message)s",
@@ -246,6 +247,7 @@ def datetime_utc_to_local(date_time: dt.datetime, timezone: ZoneInfo) -> dt.date
Converts a datetime object in UTC to local time, represented as a
timezone naive datetime object.
"""
+ date_time = date_time.to_pydatetime()
new_tz = date_time.astimezone(timezone)
return new_tz.replace(tzinfo=None)
diff --git a/tests/test_algo.py b/tests/test_algo.py
index c657f4a0..913acab2 100644
--- a/tests/test_algo.py
+++ b/tests/test_algo.py
@@ -254,7 +254,7 @@ def test_buy_sell_option_auto(self, mock_mark_up):
p = t.option_positions[0]
self.assertEqual(p["symbol"], "X 110101C01000000")
- t.algo[0].sell_option()
+ t.algo[0].sell_all_options()
streamer.tick()
# p = t.stock_positions[0]
diff --git a/tests/test_api.py b/tests/test_api.py
index f109596a..2f6a3d0e 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -25,7 +25,7 @@ def setUpClass(self):
def test_timeout(self):
stream = StreamAPI()
stream.fetch_account = lambda: None
- stream.fetch_price_history = lambda x, y: pd.DataFrame()
+ stream.fetch_price_history = lambda x, y, z: pd.DataFrame()
stream.fetch_account = lambda: {"cash": 100, "equity": 100}
t = PaperTrader(stream)
stream.trader = t
@@ -62,20 +62,20 @@ def test_timeout(self):
data["A"]["A"]["close"][-1],
t.storage.load("A", Interval.MIN_1)["A"]["close"][-1],
)
- # Check if B has been duplicated
+ # Check if B has been set to 0
self.assertEqual(
b_cur["B"]["close"][-1],
t.storage.load("B", Interval.MIN_1)["B"]["close"][-2],
)
self.assertEqual(
- b_cur["B"]["close"][-1],
+ 0,
t.storage.load("B", Interval.MIN_1)["B"]["close"][-1],
)
def test_timeout_cancel(self):
stream = StreamAPI()
stream.fetch_account = lambda: None
- stream.fetch_price_history = lambda x, y: pd.DataFrame()
+ stream.fetch_price_history = lambda x, y, z: pd.DataFrame()
stream.fetch_account = lambda: {"cash": 100, "equity": 100}
t = PaperTrader(stream)
t.set_algo(BaseAlgo())
@@ -130,11 +130,7 @@ def test_timeout_cancel(self):
def test_exceptions(self):
api = API()
- try:
- api.create_secret("I dont exists")
- self.assertTrue(False)
- except Exception as e:
- self.assertEqual(str(e), "I dont exists was not found.")
+ self.assertEqual(api.create_secret("I dont exists"), False)
try:
api.fetch_price_history('A', Interval.MIN_1, now(), now())
@@ -185,10 +181,16 @@ def test_exceptions(self):
self.assertEqual(str(e), "API does not support this broker method: `fetch_crypto_order_status`.")
try:
- api.order_limit("buy", "A", 5, 7)
+ api.order_stock_limit("buy", "A", 5, 7)
+ self.assertTrue(False)
+ except NotImplementedError as e:
+ self.assertEqual(str(e), "API does not support this broker method: `order_stock_limit`.")
+
+ try:
+ api.order_crypto_limit("buy", "@A", 5, 7)
self.assertTrue(False)
except NotImplementedError as e:
- self.assertEqual(str(e), "API does not support this broker method: `order_limit`.")
+ self.assertEqual(str(e), "API does not support this broker method: `order_crypto_limit`.")
try:
api.order_option_limit("buy", "A", 5, 7, "call", now(), 8)
@@ -196,6 +198,20 @@ def test_exceptions(self):
except NotImplementedError as e:
self.assertEqual(str(e), "API does not support this broker method: `order_option_limit`.")
+ try:
+ api.buy('A', -1, 0)
+ self.assertTrue(False)
+ except NotImplementedError as e:
+ self.assertEqual(str(e), "API does not support this broker method: `order_stock_limit`.")
+
+ try:
+ api.sell('A', -1, 0)
+ self.assertTrue(False)
+ except NotImplementedError as e:
+ self.assertEqual(str(e), "API does not support this broker method: `order_stock_limit`.")
+
+
+
def test_base_cases(self):
api = API()
@@ -206,15 +222,13 @@ def test_base_cases(self):
self.assertEqual(api.fetch_crypto_positions(), [])
api.update_option_positions([])
self.assertEqual(api.fetch_order_queue(), [])
- self.assertTrue(api.buy('A', -1) is None)
- self.assertTrue(api.buy_option("A", -1) is None)
def test_run_once(self):
api = API()
fn = lambda x: x + 1
wrapper = API._run_once(fn)
- self.assertEqual(wrapper(api)(5), 6)
- self.assertTrue(wrapper(api) is None)
+ self.assertEqual(wrapper(5), 6)
+ self.assertTrue(wrapper(5) is None)
def test_timestamp(self):
api = API()
diff --git a/tests/test_api_paper.py b/tests/test_api_paper.py
index ec640eba..6b2c9671 100644
--- a/tests/test_api_paper.py
+++ b/tests/test_api_paper.py
@@ -12,7 +12,6 @@
class TestPaperBroker(unittest.TestCase):
def test_account(self):
dummy = PaperBroker()
- dummy.streamer = DummyStreamer()
d = dummy.fetch_account()
self.assertEqual(d["equity"], 1000000.0)
self.assertEqual(d["cash"], 1000000.0)
@@ -22,7 +21,6 @@ def test_account(self):
def test_dummy_account(self):
directory = pathlib.Path(__file__).parent.resolve()
dummy = PaperBroker(str(directory) + "/../dummy_account.yaml")
- dummy.streamer = DummyStreamer()
stocks = dummy.fetch_stock_positions()
self.assertEqual(len(stocks), 2)
self.assertEqual(stocks[0]["symbol"], "A")
@@ -37,7 +35,6 @@ def test_dummy_account(self):
def test_buy_order_limit(self):
dummy = PaperBroker()
- dummy.streamer = DummyStreamer()
interval = {"A": {"interval": Interval.MIN_1, "aggregations": []}}
dummy.setup(interval)
order = dummy.order_stock_limit("buy", "A", 5, 50000)
@@ -56,10 +53,9 @@ def test_buy_order_limit(self):
def test_buy(self):
dummy = PaperBroker()
- dummy.streamer = DummyStreamer()
interval = {"A": {"interval": Interval.MIN_1, "aggregations": []}}
dummy.setup(interval)
- order = dummy.buy("A", 5)
+ order = dummy.buy("A", 5, 1e5)
self.assertEqual(order["type"], "STOCK")
self.assertEqual(order["id"], 0)
self.assertEqual(order["symbol"], "A")
@@ -76,7 +72,6 @@ def test_buy(self):
def test_sell_order_limit(self):
directory = pathlib.Path(__file__).parent.resolve()
dummy = PaperBroker(str(directory) + "/../dummy_account.yaml")
- dummy.streamer = DummyStreamer()
interval = {"A": {"interval": Interval.MIN_1, "aggregations": []}}
dummy.setup(interval)
order = dummy.order_stock_limit("sell", "A", 2, 50000)
@@ -96,7 +91,6 @@ def test_sell_order_limit(self):
def test_sell(self):
directory = pathlib.Path(__file__).parent.resolve()
dummy = PaperBroker(str(directory) + "/../dummy_account.yaml")
- dummy.streamer = DummyStreamer()
interval = {"A": {"interval": Interval.MIN_1, "aggregations": []}}
dummy.setup(interval)
order = dummy.sell("A", 2)
@@ -115,19 +109,18 @@ def test_sell(self):
def test_order_option_limit(self):
dummy = PaperBroker()
- dummy.streamer = DummyStreamer()
interval = {"A": {"interval": Interval.MIN_1, "aggregations": []}}
dummy.setup(interval)
- exp_date = dt.datetime.now() + dt.timedelta(hours=5)
+ exp_date = dt.datetime(2021, 11, 14) + dt.timedelta(hours=5)
order = dummy.order_option_limit(
"buy", "A", 5, 50000, "OPTION", exp_date, 50001
)
self.assertEqual(order["type"], "OPTION")
self.assertEqual(order["id"], 0)
- self.assertEqual(order["symbol"], "A 211106P50001000")
+ self.assertEqual(order["symbol"], "A 211114P50001000")
status = dummy.fetch_option_order_status(order["id"])
- self.assertEqual(status["symbol"], "A 211106P50001000")
+ self.assertEqual(status["symbol"], "A 211114P50001000")
self.assertEqual(status["quantity"], 5)
def test_commission(self):
diff --git a/tests/test_api_webull.py b/tests/test_api_webull.py
index 3656cebd..2fe0220d 100644
--- a/tests/test_api_webull.py
+++ b/tests/test_api_webull.py
@@ -58,7 +58,7 @@ def test_main(df):
"@BTC": {"interval": Interval.MIN_1, "aggregations": []},
"SPY": {"interval": Interval.MIN_1, "aggregations": []},
}
- wb.setup(interval, None, test_main)
+ wb.setup(interval, test_main)
wb.main()
@not_gh_action
@@ -72,7 +72,7 @@ def test_main(df):
wb = Webull()
watch = ["SPY"]
interval = {"SPY": {"interval": Interval.MIN_1, "aggregations": []}}
- wb.setup(interval, None, test_main)
+ wb.setup(interval, test_main)
wb.main()
@not_gh_action
diff --git a/tests/test_api_yahoo.py b/tests/test_api_yahoo.py
index f9734e67..46994d1a 100644
--- a/tests/test_api_yahoo.py
+++ b/tests/test_api_yahoo.py
@@ -43,7 +43,7 @@ def test_main(df):
yh = YahooStreamer()
watch = ["SPY", "AAPL", "@BTC"]
- yh.setup(interval, None, test_main)
+ yh.setup(interval, test_main)
yh.main()
def test_main_single(self):
@@ -54,14 +54,14 @@ def test_main(df):
self.assertEqual(df["SPY"].columns[0][0], "SPY")
yh = YahooStreamer()
- yh.setup(interval, None, test_main)
+ yh.setup(interval, test_main)
yh.main()
def test_chain_info(self):
t = PaperTrader()
yh = YahooStreamer()
interval = {"LMND": {"interval": Interval.MIN_1, "aggregations": []}}
- yh.setup(interval, t, None)
+ yh.setup(interval)
info = yh.fetch_chain_info("LMND")
self.assertGreater(len(info["exp_dates"]), 0)
@@ -69,7 +69,7 @@ def test_chain_data(self):
t = PaperTrader()
yh = YahooStreamer()
interval = {"LMND": {"interval": Interval.MIN_1, "aggregations": []}}
- yh.setup(interval, t, None)
+ yh.setup(interval)
dates = yh.fetch_chain_info("LMND")["exp_dates"]
data = yh.fetch_chain_data("LMND", dates[0])
self.assertGreater(len(data), 0)
diff --git a/tests/test_storage.py b/tests/test_storage.py
index bd2a7e20..4347be7a 100644
--- a/tests/test_storage.py
+++ b/tests/test_storage.py
@@ -11,7 +11,7 @@
class TestBaseStorage(unittest.TestCase):
def test_create_storage(self):
storage = BaseStorage()
- self.assertEqual(storage.storage, {})
+ self.assertEqual(storage.storage_price, {})
def test_simple_store(self):
storage = BaseStorage()
@@ -19,8 +19,8 @@ def test_simple_store(self):
storage.store("A", Interval.MIN_1, data.copy(True))
self.assertTrue(not pd.isna(data.iloc[0]["A"]["low"]))
- self.assertEqual(list(storage.storage.keys()), ["A"])
- self.assertEqual(list(storage.storage["A"].keys()), [Interval.MIN_1])
+ self.assertEqual(list(storage.storage_price.keys()), ["A"])
+ self.assertEqual(list(storage.storage_price["A"].keys()), [Interval.MIN_1])
def test_simple_load(self):
storage = BaseStorage()
diff --git a/tests/test_storage_csv.py b/tests/test_storage_csv.py
index 319901f5..156b193f 100644
--- a/tests/test_storage_csv.py
+++ b/tests/test_storage_csv.py
@@ -18,7 +18,7 @@ def setUpClass(self):
def test_create_storage(self):
storage = CSVStorage(self.storage_dir)
- self.assertEqual(storage.storage, {})
+ self.assertEqual(storage.storage_price, {})
def test_simple_store(self):
storage = CSVStorage(self.storage_dir)
@@ -26,8 +26,8 @@ def test_simple_store(self):
storage.store("A", Interval.MIN_1, data.copy(True))
self.assertTrue(not pd.isna(data.iloc[0]["A"]["low"]))
- self.assertTrue("A" in list(storage.storage.keys()))
- self.assertEqual(list(storage.storage["A"].keys()), [Interval.MIN_1])
+ self.assertTrue("A" in list(storage.storage_price.keys()))
+ self.assertEqual(list(storage.storage_price["A"].keys()), [Interval.MIN_1])
def test_saved_load(self):
storage1 = CSVStorage(self.storage_dir)
diff --git a/tests/test_storage_pickle.py b/tests/test_storage_pickle.py
index 808724ef..59684af6 100644
--- a/tests/test_storage_pickle.py
+++ b/tests/test_storage_pickle.py
@@ -18,7 +18,7 @@ def setUpClass(self):
def test_create_storage(self):
storage = PickleStorage(self.storage_dir)
- self.assertEqual(storage.storage, {})
+ self.assertEqual(storage.storage_price, {})
def test_simple_store(self):
storage = PickleStorage(self.storage_dir)
@@ -26,8 +26,8 @@ def test_simple_store(self):
storage.store("A", Interval.MIN_1, data.copy(True))
self.assertTrue(not pd.isna(data.iloc[0]["A"]["low"]))
- self.assertTrue("A" in list(storage.storage.keys()))
- self.assertEqual(list(storage.storage["A"].keys()), [Interval.MIN_1])
+ self.assertTrue("A" in list(storage.storage_price.keys()))
+ self.assertEqual(list(storage.storage_price["A"].keys()), [Interval.MIN_1])
def test_saved_load(self):
storage1 = PickleStorage(self.storage_dir)
diff --git a/tests/test_tester.py b/tests/test_tester.py
index 805f64f2..b0bcc67a 100644
--- a/tests/test_tester.py
+++ b/tests/test_tester.py
@@ -43,7 +43,7 @@ def test_check_aggregation(self):
minutes = list(t.storage.load("A", Interval.MIN_1)["A"]["close"])[-200:]
days_agg = list(
- t.storage.load("A", int(Interval.DAY_1) - 16, no_slice=True)["A"]["close"]
+ t.storage.load("A", int(Interval.DAY_1) - 16)["A"]["close"]
)[-200:]
self.assertListEqual(minutes, days_agg)