Add show_progress flag to BulkImportWriter (#141)
* Add show_progress flag to bulk writer

* Return comment

* Run formatters, bump version, add whitespace in error message

* address PR comments

* refactor
DavidLandup0 authored Dec 6, 2024
1 parent 7357c4e commit 74abe50
Showing 5 changed files with 65 additions and 16 deletions.
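For reference, a minimal usage sketch of the new flag through pytd's public API, assuming the keyword argument is forwarded from load_table_from_dataframe to BulkImportWriter.write_dataframe (the apikey, database, and table names below are placeholders):

import pandas as pd
import pytd

client = pytd.Client(apikey="1/XXXX", database="sample_db")
df = pd.DataFrame({"id": range(100_000), "value": range(100_000)})

# show_progress=True displays tqdm bars while the DataFrame is chunked into
# msgpack parts and while those parts are uploaded via the bulk import API.
client.load_table_from_dataframe(
    df,
    "my_table",
    writer="bulk_import",
    if_exists="overwrite",
    show_progress=True,
)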
2 changes: 1 addition & 1 deletion pytd/client.py
@@ -89,7 +89,7 @@ def __init__(
         if apikey is None:
             raise ValueError(
                 "either argument 'apikey' or environment variable"
-                "'TD_API_KEY' should be set"
+                " 'TD_API_KEY' should be set"
             )
         if endpoint is None:
             endpoint = os.getenv("TD_API_SERVER", "https://api.treasuredata.com")
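The one-line change above fixes Python's implicit string-literal concatenation: adjacent literals are joined with no separator, so the original message rendered as "...environment variable'TD_API_KEY' should be set". A minimal illustration:

# Adjacent string literals are concatenated at compile time with no space added.
broken = (
    "either argument 'apikey' or environment variable"
    "'TD_API_KEY' should be set"
)
fixed = (
    "either argument 'apikey' or environment variable"
    " 'TD_API_KEY' should be set"
)
print(broken)  # ...environment variable'TD_API_KEY' should be set
print(fixed)   # ...environment variable 'TD_API_KEY' should be set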
6 changes: 3 additions & 3 deletions pytd/pandas_td/ipython.py
@@ -1,10 +1,10 @@
 """IPython Magics
-IPython magics to access to Treasure Data. Load the magics first of all:
+IPython magics to access to Treasure Data. Load the magics first of all:
-.. code-block:: ipython
+.. code-block:: ipython
-    In [1]: %load_ext pytd.pandas_td.ipython
+    In [1]: %load_ext pytd.pandas_td.ipython
 """

 import argparse

(The three changed docstring lines are whitespace-only edits from the formatter run mentioned in the commit message; the visible text is unchanged.)
2 changes: 1 addition & 1 deletion pytd/spark.py
@@ -96,7 +96,7 @@ def fetch_td_spark_context(
     if apikey is None:
         raise ValueError(
             "either argument 'apikey' or environment variable"
-            "'TD_API_KEY' should be set"
+            " 'TD_API_KEY' should be set"
         )
     if endpoint is None:
         endpoint = os.getenv("TD_API_SERVER", "https://api.treasuredata.com")
68 changes: 58 additions & 10 deletions pytd/writer.py
@@ -12,6 +12,7 @@
 import numpy as np
 import pandas as pd
 from tdclient.util import normalized_msgpack
+from tqdm import tqdm

 from .spark import fetch_td_spark_context

@@ -321,6 +322,7 @@ def write_dataframe(
         keep_list=False,
         max_workers=5,
         chunk_record_size=10_000,
+        show_progress=False,
     ):
         """Write a given DataFrame to a Treasure Data table.
@@ -367,9 +369,14 @@
             will be converted array<T> on Treasure Data table.
             Each type of element of list will be converted by
             ``numpy.array(your_list).tolist()``.
             If True, ``fmt`` argument will be overwritten with ``msgpack``.
+        show_progress : boolean, default: False
+            If this argument is True, shows a TQDM progress bar
+            for chunking data into msgpack format and uploading before
+            performing a bulk import.
+
         Examples
         ---------
@@ -456,7 +463,15 @@
         try:
             with ThreadPoolExecutor(max_workers=max_workers) as executor:
                 futures = []
-                for start in range(0, num_rows, _chunk_record_size):
+                chunk_range = (
+                    tqdm(
+                        range(0, num_rows, _chunk_record_size),
+                        desc="Chunking data",
+                    )
+                    if show_progress
+                    else range(0, num_rows, _chunk_record_size)
+                )
+                for start in chunk_range:
                     records = dataframe.iloc[
                         start : start + _chunk_record_size
                     ].to_dict(orient="records")
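The hunk above wraps the chunk iterator in tqdm only when show_progress is set, leaving the consuming loop unchanged. A standalone sketch of the same pattern (the names here are illustrative, not from pytd):

from tqdm import tqdm

def chunk_starts(n_rows, chunk_size, show_progress=False):
    """Yield chunk start offsets, optionally wrapped in a tqdm progress bar."""
    starts = range(0, n_rows, chunk_size)
    # tqdm accepts any iterable; with a range it infers the total from len(),
    # so the wrapped and unwrapped forms behave identically for the caller.
    return tqdm(starts, desc="Chunking data") if show_progress else starts

for start in chunk_starts(1_000_000, 10_000, show_progress=True):
    pass  # process rows [start, start + chunk_size) here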
@@ -473,7 +488,12 @@
                     )
                     stack.callback(os.unlink, fp.name)
                     stack.callback(fp.close)
-                for start, future in sorted(futures):
+                resolve_range = (
+                    tqdm(sorted(futures), desc="Resolving futures")
+                    if show_progress
+                    else sorted(futures)
+                )
+                for start, future in resolve_range:
                     fps.append(future.result())
         except OSError as e:
             raise RuntimeError(
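Each future is stored as a (start, future) pair, so sorting before calling result() reassembles the msgpack parts in row order even though worker threads finish in arbitrary order. A small self-contained sketch of that pattern (the square task is a stand-in for the real serialization work):

from concurrent.futures import ThreadPoolExecutor

def square(x):
    return x * x

with ThreadPoolExecutor(max_workers=4) as executor:
    futures = [(start, executor.submit(square, start)) for start in range(0, 50, 10)]
    # sorted() orders tuples by their first element (the start offset), so
    # results are collected in submission order, not completion order.
    results = [future.result() for start, future in sorted(futures)]

print(results)  # [0, 100, 400, 900, 1600]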
@@ -485,10 +505,25 @@
                 f"unsupported format '{fmt}' for bulk import. "
                 "should be 'csv' or 'msgpack'"
             )
-        self._bulk_import(table, fps, if_exists, fmt, max_workers=max_workers)
+        self._bulk_import(
+            table,
+            fps,
+            if_exists,
+            fmt,
+            max_workers=max_workers,
+            show_progress=show_progress,
+        )
         stack.close()

-    def _bulk_import(self, table, file_likes, if_exists, fmt="csv", max_workers=5):
+    def _bulk_import(
+        self,
+        table,
+        file_likes,
+        if_exists,
+        fmt="csv",
+        max_workers=5,
+        show_progress=False,
+    ):
         """Write a specified CSV file to a Treasure Data table.

         This method uploads the file to Treasure Data via bulk import API.
@@ -515,6 +550,10 @@ def _bulk_import(self, table, file_likes, if_exists, fmt="csv", max_workers=5):
         max_workers : int, optional, default: 5
             The maximum number of threads that can be used to execute the given calls.
             This is used only when ``fmt`` is ``msgpack``.
+        show_progress : boolean, default: False
+            If this argument is True, shows a TQDM progress bar
+            for the upload process performed on multiple threads.
+
         """
         params = None
         if table.exists:
@@ -544,16 +583,25 @@ def _bulk_import(self, table, file_likes, if_exists, fmt="csv", max_workers=5):
         logger.info(f"uploading data converted into a {fmt} file")
         if fmt == "msgpack":
             with ThreadPoolExecutor(max_workers=max_workers) as executor:
+                futures = []
                 for i, fp in enumerate(file_likes):
                     fsize = fp.tell()
                     fp.seek(0)
-                    executor.submit(
-                        bulk_import.upload_part,
-                        f"part-{i}",
-                        fp,
-                        fsize,
+                    futures.append(
+                        executor.submit(
+                            bulk_import.upload_part,
+                            f"part-{i}",
+                            fp,
+                            fsize,
+                        )
                     )
                     logger.debug(f"to upload {fp.name} to TD. File size: {fsize}B")
+                if show_progress:
+                    for _ in tqdm(futures, desc="Uploading parts"):
+                        _.result()
+                else:
+                    for future in futures:
+                        future.result()
         else:
             fp = file_likes[0]
             bulk_import.upload_file("part", fmt, fp)
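Collecting the futures and calling result() on each, instead of discarding the value returned by executor.submit, both surfaces exceptions raised inside upload_part and gives tqdm a concrete list to iterate for the progress bar. A minimal illustration with a stand-in upload task:

from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm

def upload(part):
    if part == 2:
        raise OSError(f"upload of part {part} failed")

with ThreadPoolExecutor(max_workers=2) as executor:
    futures = [executor.submit(upload, part) for part in range(4)]
    # result() blocks until the part finishes and re-raises any worker
    # exception; a discarded future would fail silently here.
    for future in tqdm(futures, desc="Uploading parts"):
        future.result()  # raises OSError once part 2 is reached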
3 changes: 2 additions & 1 deletion setup.cfg
@@ -34,6 +34,7 @@ install_requires =
     numpy>1.17.3, <2.0.0
     td-client>=1.1.0
     pytz>=2018.5
+    tqdm>=4.60.0

 [options.extras_require]
 spark =
@@ -65,7 +66,7 @@ exclude =
     doc/conf.py

 [isort]
-known_third_party = IPython,msgpack,nox,numpy,pandas,pkg_resources,prestodb,pytz,setuptools,tdclient
+known_third_party = IPython,msgpack,nox,numpy,pandas,pkg_resources,prestodb,pytz,setuptools,tdclient,tqdm
 line_length=88
 multi_line_output=3
 include_trailing_comma=True
