Skip to content

Commit

Permalink
⚡️ perf: concurrent fetch all
Browse files Browse the repository at this point in the history
  • Loading branch information
ljnsn committed Apr 28, 2024
1 parent fc4c1e9 commit d5428f1
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 2 deletions.
26 changes: 24 additions & 2 deletions dsws_client/client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""The DSWS client."""

import concurrent.futures
import itertools
import logging
import sys
Expand Down Expand Up @@ -52,6 +53,7 @@ def __init__(self, username: str, password: str, **kwargs: Any) -> None:
verify=config.ssl_cert or True,
headers={"Content-Type": "application/json"},
)
self._max_concurrency = config.max_concurrency
self._app_id = config.app_id
self._data_source = config.data_source
self._debug = config.debug
Expand Down Expand Up @@ -83,7 +85,7 @@ def fetch_snapshot_data(
return_symbol_names=True,
return_field_names=True,
)
responses = self.fetch_all(request_bundles)
responses = self.fetch_all(request_bundles, threaded=self._max_concurrency > 1)
data_responses = itertools.chain.from_iterable(
response.data_responses for response in responses
)
Expand All @@ -110,7 +112,7 @@ def fetch_timeseries_data( # noqa: PLR0913
return_symbol_names=True,
return_field_names=True,
)
responses = self.fetch_all(request_bundles)
responses = self.fetch_all(request_bundles, threaded=self._max_concurrency > 1)
data_responses = itertools.chain.from_iterable(
response.data_responses for response in responses
)
Expand Down Expand Up @@ -171,11 +173,31 @@ def fetch_bundle(
def fetch_all(
    self,
    request_bundles: List[List[DSDataRequest]],
    *,
    threaded: bool = False,
) -> Iterator[DSGetDataBundleResponse]:
    """
    Fetch as many bundles as needed to get all items.

    Args:
        request_bundles: The request bundles to fetch. Each bundle is sent
            with its own ``fetch_bundle`` call.
        threaded: When true, delegate to ``fetch_all_threaded`` instead of
            fetching the bundles one after another.

    Yields:
        One response per request bundle.
    """
    if threaded:
        yield from self.fetch_all_threaded(request_bundles)
        # Stop here: falling through to the sequential loop below would
        # fetch (and yield) every bundle a second time.
        return
    for bundle in request_bundles:
        yield self.fetch_bundle(bundle)

def fetch_all_threaded(
    self,
    request_bundles: List[List[DSDataRequest]],
) -> Iterator[DSGetDataBundleResponse]:
    """
    Fetch as many bundles as needed to get all items (concurrently).

    Every bundle is submitted to a thread pool limited to
    ``self._max_concurrency`` workers. The call blocks until all requests
    have finished, then returns an iterator over the responses.

    Args:
        request_bundles: The request bundles to fetch. Each bundle is sent
            with its own ``fetch_bundle`` call.

    Returns:
        An iterator over the responses, in the same order as
        ``request_bundles``. An exception raised by a ``fetch_bundle`` call
        is re-raised when its response is reached during iteration.
    """
    with concurrent.futures.ThreadPoolExecutor(
        max_workers=self._max_concurrency
    ) as executor:
        logger.debug("fetching bundles in parallel")
        futures = [
            executor.submit(self.fetch_bundle, bundle) for bundle in request_bundles
        ]
    # Leaving the ``with`` block waits for every future, so all of them are
    # done by the time the caller iterates. ``as_completed`` would then yield
    # them in arbitrary order; iterating the list keeps the responses in
    # submission order, matching the sequential path.
    return (future.result() for future in futures)

def fetch_token(self, **kwargs: object) -> Token:
"""
Fetch a new token.
Expand Down
1 change: 1 addition & 0 deletions dsws_client/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class DSWSConfig:
ssl_cert: Optional[str] = None
app_id: str = f"dsws-client-{__version__}"
data_source: Optional[str] = None
max_concurrency: int = 1
debug: bool = attrs.field(default=False, converter=attrs.converters.to_bool)

def __init__(self, **kwargs: Any) -> None:
Expand Down

0 comments on commit d5428f1

Please sign in to comment.