From 74abe50ab71fe8213959a5882cf62614046d23f2 Mon Sep 17 00:00:00 2001
From: David Landup <60978046+DavidLandup0@users.noreply.github.com>
Date: Fri, 6 Dec 2024 14:39:11 +0900
Subject: [PATCH] Add `show_progress` flag to `BulkImportWriter` (#141)

* Add show_progress flag to bulk writer

* Return comment

* Run formatters, bump version, add whitespace in error message

* address PR comments

* refactor
---
 pytd/client.py            |  2 +-
 pytd/pandas_td/ipython.py |  6 ++--
 pytd/spark.py             |  2 +-
 pytd/writer.py            | 68 +++++++++++++++++++++++++++++++++------
 setup.cfg                 |  3 +-
 5 files changed, 65 insertions(+), 16 deletions(-)

diff --git a/pytd/client.py b/pytd/client.py
index 382a5fc..1aad656 100644
--- a/pytd/client.py
+++ b/pytd/client.py
@@ -89,7 +89,7 @@ def __init__(
         if apikey is None:
             raise ValueError(
                 "either argument 'apikey' or environment variable"
-                "'TD_API_KEY' should be set"
+                " 'TD_API_KEY' should be set"
             )
         if endpoint is None:
             endpoint = os.getenv("TD_API_SERVER", "https://api.treasuredata.com")
diff --git a/pytd/pandas_td/ipython.py b/pytd/pandas_td/ipython.py
index 26931fa..5e840c0 100644
--- a/pytd/pandas_td/ipython.py
+++ b/pytd/pandas_td/ipython.py
@@ -1,10 +1,10 @@
 """IPython Magics
 
-    IPython magics to access to Treasure Data. Load the magics first of all:
+IPython magics to access to Treasure Data. Load the magics first of all:
 
-    .. code-block:: ipython
+.. code-block:: ipython
 
-        In [1]: %load_ext pytd.pandas_td.ipython
+    In [1]: %load_ext pytd.pandas_td.ipython
 """
 
 import argparse
diff --git a/pytd/spark.py b/pytd/spark.py
index a84bc96..ccf4e3b 100644
--- a/pytd/spark.py
+++ b/pytd/spark.py
@@ -96,7 +96,7 @@ def fetch_td_spark_context(
     if apikey is None:
         raise ValueError(
             "either argument 'apikey' or environment variable"
-            "'TD_API_KEY' should be set"
+            " 'TD_API_KEY' should be set"
         )
     if endpoint is None:
         endpoint = os.getenv("TD_API_SERVER", "https://api.treasuredata.com")
diff --git a/pytd/writer.py b/pytd/writer.py
index a53508a..989dca5 100644
--- a/pytd/writer.py
+++ b/pytd/writer.py
@@ -12,6 +12,7 @@
 import numpy as np
 import pandas as pd
 from tdclient.util import normalized_msgpack
+from tqdm import tqdm
 
 from .spark import fetch_td_spark_context
 
@@ -321,6 +322,7 @@ def write_dataframe(
         keep_list=False,
         max_workers=5,
         chunk_record_size=10_000,
+        show_progress=False,
     ):
         """Write a given DataFrame to a Treasure Data table.
 
@@ -367,9 +369,14 @@ def write_dataframe(
             will be converted array on Treasure Data table.
             Each type of element of list will be converted by
             ``numpy.array(your_list).tolist()``.
-            If True, ``fmt`` argument will be overwritten with ``msgpack``.
+            If True, ``fmt`` argument will be overwritten with ``msgpack``.
+
+        show_progress : boolean, default: False
+            If this argument is True, shows a TQDM progress bar
+            for chunking data into msgpack format and uploading before
+            performing a bulk import.
 
         Examples
         ---------
@@ -456,7 +463,15 @@ def write_dataframe(
         try:
             with ThreadPoolExecutor(max_workers=max_workers) as executor:
                 futures = []
-                for start in range(0, num_rows, _chunk_record_size):
+                chunk_range = (
+                    tqdm(
+                        range(0, num_rows, _chunk_record_size),
+                        desc="Chunking data",
+                    )
+                    if show_progress
+                    else range(0, num_rows, _chunk_record_size)
+                )
+                for start in chunk_range:
                     records = dataframe.iloc[
                         start : start + _chunk_record_size
                     ].to_dict(orient="records")
@@ -473,7 +488,12 @@ def write_dataframe(
                     )
                     stack.callback(os.unlink, fp.name)
                     stack.callback(fp.close)
-                for start, future in sorted(futures):
+                resolve_range = (
+                    tqdm(sorted(futures), desc="Resolving futures")
+                    if show_progress
+                    else sorted(futures)
+                )
+                for start, future in resolve_range:
                     fps.append(future.result())
         except OSError as e:
             raise RuntimeError(
@@ -485,10 +505,25 @@ def write_dataframe(
                 f"unsupported format '{fmt}' for bulk import. "
                 "should be 'csv' or 'msgpack'"
             )
-        self._bulk_import(table, fps, if_exists, fmt, max_workers=max_workers)
+        self._bulk_import(
+            table,
+            fps,
+            if_exists,
+            fmt,
+            max_workers=max_workers,
+            show_progress=show_progress,
+        )
         stack.close()
 
-    def _bulk_import(self, table, file_likes, if_exists, fmt="csv", max_workers=5):
+    def _bulk_import(
+        self,
+        table,
+        file_likes,
+        if_exists,
+        fmt="csv",
+        max_workers=5,
+        show_progress=False,
+    ):
         """Write a specified CSV file to a Treasure Data table.
 
         This method uploads the file to Treasure Data via bulk import API.
@@ -515,6 +550,10 @@ def _bulk_import(self, table, file_likes, if_exists, fmt="csv", max_workers=5):
         max_workers : int, optional, default: 5
             The maximum number of threads that can be used to execute the
             given calls. This is used only when ``fmt`` is ``msgpack``.
+
+        show_progress : boolean, default: False
+            If this argument is True, shows a TQDM progress bar
+            for the upload process performed on multiple threads.
         """
         params = None
         if table.exists:
@@ -544,16 +583,25 @@ def _bulk_import(self, table, file_likes, if_exists, fmt="csv", max_workers=5):
         logger.info(f"uploading data converted into a {fmt} file")
         if fmt == "msgpack":
             with ThreadPoolExecutor(max_workers=max_workers) as executor:
+                futures = []
                 for i, fp in enumerate(file_likes):
                     fsize = fp.tell()
                     fp.seek(0)
-                    executor.submit(
-                        bulk_import.upload_part,
-                        f"part-{i}",
-                        fp,
-                        fsize,
+                    futures.append(
+                        executor.submit(
+                            bulk_import.upload_part,
+                            f"part-{i}",
+                            fp,
+                            fsize,
+                        )
                     )
                     logger.debug(f"to upload {fp.name} to TD. File size: {fsize}B")
+                if show_progress:
+                    for _ in tqdm(futures, desc="Uploading parts"):
+                        _.result()
+                else:
+                    for future in futures:
+                        future.result()
         else:
             fp = file_likes[0]
             bulk_import.upload_file("part", fmt, fp)
diff --git a/setup.cfg b/setup.cfg
index 7caab83..db117aa 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -34,6 +34,7 @@ install_requires =
     numpy>1.17.3, <2.0.0
     td-client>=1.1.0
     pytz>=2018.5
+    tqdm>=4.60.0
 
 [options.extras_require]
 spark =
@@ -65,7 +66,7 @@ exclude =
     doc/conf.py
 
 [isort]
-known_third_party = IPython,msgpack,nox,numpy,pandas,pkg_resources,prestodb,pytz,setuptools,tdclient
+known_third_party = IPython,msgpack,nox,numpy,pandas,pkg_resources,prestodb,pytz,setuptools,tdclient,tqdm
 line_length=88
 multi_line_output=3
 include_trailing_comma=True
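
Usage note (not part of the patch): a minimal sketch of the new flag as seen from
pytd's public API. It assumes Client.load_table_from_dataframe forwards extra
keyword arguments through to BulkImportWriter.write_dataframe; the API key,
database, and table names below are placeholders.

    import pandas as pd

    import pytd

    client = pytd.Client(apikey="1/XXX", database="sample_datasets")
    df = pd.DataFrame({"a": range(100_000), "b": range(100_000)})

    # show_progress=True renders tqdm bars for the chunking, future-resolution,
    # and upload stages; fmt="msgpack" selects the multi-part upload path.
    client.load_table_from_dataframe(
        df,
        "sample_datasets.progress_demo",
        writer="bulk_import",
        if_exists="overwrite",
        fmt="msgpack",
        show_progress=True,
    )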
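Implementation note: the upload path now collects its executor.submit handles in
a list and only resolves them afterwards, so tqdm can advance as each part
completes. A self-contained sketch of that pattern under stated assumptions
(upload_part below is a stand-in for the tdclient upload call, not its real API):

    import time
    from concurrent.futures import ThreadPoolExecutor

    from tqdm import tqdm

    def upload_part(i):
        # Stand-in for bulk_import.upload_part: simulate uploading one chunk.
        time.sleep(0.1)
        return i

    with ThreadPoolExecutor(max_workers=5) as executor:
        # Submit every part first so the uploads overlap across threads...
        futures = [executor.submit(upload_part, i) for i in range(20)]
        # ...then block on them in submission order under a progress bar.
        for future in tqdm(futures, desc="Uploading parts"):
            future.result()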