From 1432baeecd5fa191b40182aa9f14a511110456d9 Mon Sep 17 00:00:00 2001 From: Daniel Nagy Date: Sat, 10 May 2025 23:45:00 +0200 Subject: [PATCH] More pythonic API by using `asyncio.Queue` For each WebSocket subscription `Queues` are now being filled. A collection task is being started at the beginning of a command to fill these queues. --- pytr/api.py | 44 ++++++++++++++++++++++++++++++++++++++ pytr/details.py | 57 ++++++++++++------------------------------------- 2 files changed, 58 insertions(+), 43 deletions(-) diff --git a/pytr/api.py b/pytr/api.py index ff2ff10..60f35a3 100644 --- a/pytr/api.py +++ b/pytr/api.py @@ -66,6 +66,7 @@ class TradeRepublicApi: _subscription_id_counter = 1 _previous_responses: Dict[str, str] = {} subscriptions: Dict[str, Dict[str, Any]] = {} + subscriptions_futures: Dict[str, asyncio.Queue] = {} _credentials_file = CREDENTIALS_FILE _cookies_file = COOKIES_FILE @@ -322,6 +323,16 @@ async def subscribe(self, payload): await ws.send(f"sub {subscription_id} {json.dumps(payload_with_token)}") return subscription_id + async def subscribe2(self, **kwargs): + subscription_id = await self._next_subscription_id() + ws = await self._get_ws() + fut: asyncio.Queue = asyncio.Queue() + self.subscriptions_futures[subscription_id] = fut + payload = json.dumps(kwargs) + self.log.debug(f"Subscribing: 'sub {subscription_id} {payload}'") + await ws.send(f"sub {subscription_id} {payload}") + return fut + async def unsubscribe(self, subscription_id): ws = await self._get_ws() @@ -372,6 +383,24 @@ async def recv(self): payload = json.loads(payload_str) if payload_str else {} raise TradeRepublicError(subscription_id, subscription, payload) + async def recv2(self): + ws = await self._get_ws() + async for response in ws: + self.log.debug(f"Received message: {response!r}") + + subscription_id = response[: response.find(" ")] + code = response[response.find(" ") + 1 : response.find(" ") + 2] + payload_str = response[response.find(" ") + 2 :].lstrip() + + if subscription_id not in self.subscriptions_futures: + if code != "C": + self.log.debug(f"No active subscription for id {subscription_id}, dropping message") + continue + queue = self.subscriptions_futures[subscription_id] + match code: + case "A": + queue.put_nowait(json.loads(payload_str)) + def _calculate_delta(self, subscription_id, delta_payload): previous_response = self._previous_responses[subscription_id] i, result = 0, [] @@ -426,12 +455,21 @@ async def portfolio_history(self, timeframe): async def instrument_details(self, isin): return await self.subscribe({"type": "instrument", "id": isin}) + async def instrument_details2(self, isin): + return await self.subscribe2(type="instrument", id=isin) + async def instrument_suitability(self, isin): return await self.subscribe({"type": "instrumentSuitability", "instrumentId": isin}) + async def instrument_suitability2(self, isin): + return await self.subscribe2(type="instrumentSuitability", instrumentId=isin) + async def stock_details(self, isin): return await self.subscribe({"type": "stockDetails", "id": isin}) + async def stock_details2(self, isin): + return await self.subscribe2(type="stockDetails", id=isin) + async def add_watchlist(self, isin): return await self.subscribe({"type": "addToWatchlist", "instrumentId": isin}) @@ -444,6 +482,9 @@ async def ticker(self, isin, exchange="LSX"): async def performance(self, isin, exchange="LSX"): return await self.subscribe({"type": "performance", "id": f"{isin}.{exchange}"}) + async def performance2(self, isin, exchange="LSX"): + return await self.subscribe2(type="performance", id=f"{isin}.{exchange}") + async def performance_history(self, isin, timeframe, exchange="LSX", resolution=None): parameters = { "type": "aggregateHistory", @@ -726,6 +767,9 @@ async def cancel_price_alarm(self, price_alarm_id): async def news(self, isin): return await self.subscribe({"type": "neonNews", "isin": isin}) + async def news2(self, isin): + return await self.subscribe2(type="neonNews", isin=isin) + async def news_subscriptions(self): return await self.subscribe({"type": "newsSubscriptions"}) diff --git a/pytr/details.py b/pytr/details.py index 8a52a58..6604215 100644 --- a/pytr/details.py +++ b/pytr/details.py @@ -1,8 +1,6 @@ import asyncio from datetime import datetime, timedelta -from pytr.utils import preview - class Details: def __init__(self, tr, isin): @@ -10,47 +8,20 @@ def __init__(self, tr, isin): self.isin = isin async def details_loop(self): - recv = 0 - await self.tr.stock_details(self.isin) - await self.tr.news(self.isin) - # await self.tr.subscribe_news(self.isin) - await self.tr.ticker(self.isin, exchange="LSX") - await self.tr.performance(self.isin, exchange="LSX") - await self.tr.instrument_details(self.isin) - await self.tr.instrument_suitability(self.isin) - - # await self.tr.add_watchlist(self.isin) - # await self.tr.remove_watchlist(self.isin) - # await self.tr.savings_plan_parameters(self.isin) - # await self.tr.unsubscribe_news(self.isin) - - while True: - _subscription_id, subscription, response = await self.tr.recv() - - if subscription["type"] == "stockDetails": - recv += 1 - self.stockDetails = response - elif subscription["type"] == "neonNews": - recv += 1 - self.neonNews = response - elif subscription["type"] == "ticker": - recv += 1 - self.ticker = response - elif subscription["type"] == "performance": - recv += 1 - self.performance = response - elif subscription["type"] == "instrument": - recv += 1 - self.instrument = response - elif subscription["type"] == "instrumentSuitability": - recv += 1 - self.instrumentSuitability = response - print("instrumentSuitability:", response) - else: - print(f"unmatched subscription of type '{subscription['type']}':\n{preview(response, num_lines=30)}") - - if recv == 6: - return + asyncio.create_task(self.tr.recv2()) + ( + self.stockDetails, + self.neonNews, + self.performance, + self.instrument, + self.instrumentSuitability, + ) = await asyncio.gather( + (await self.tr.stock_details2(self.isin)).get(), + (await self.tr.news2(self.isin)).get(), + (await self.tr.performance2(self.isin, exchange="LSX")).get(), + (await self.tr.instrument_details2(self.isin)).get(), + (await self.tr.instrument_suitability2(self.isin)).get(), + ) def print_instrument(self): print("Name:", self.instrument["name"])