From 5b78706ba90484ea75c9f82677941aef8ab40996 Mon Sep 17 00:00:00 2001
From: Aki Ariga
Date: Sun, 15 Sep 2024 17:30:20 -0700
Subject: [PATCH] Create temporary file within ThreadPoolExecutor

---
 pytd/writer.py | 111 ++++++++++++++++++++++---------------------------
 1 file changed, 49 insertions(+), 62 deletions(-)

diff --git a/pytd/writer.py b/pytd/writer.py
index d490eb4..4c48cef 100644
--- a/pytd/writer.py
+++ b/pytd/writer.py
@@ -439,46 +439,9 @@ def write_dataframe(
             fmt = "msgpack"
 
         _cast_dtypes(dataframe, keep_list=keep_list)
+        self._bulk_import(table, dataframe, if_exists, fmt, max_workers=max_workers, chunk_record_size=chunk_record_size)
 
-        with ExitStack() as stack:
-            fps = []
-            if fmt == "csv":
-                fp = tempfile.NamedTemporaryFile(suffix=".csv", delete=False)
-                stack.callback(os.unlink, fp.name)
-                stack.callback(fp.close)
-                dataframe.to_csv(fp.name)
-                fps.append(fp)
-            elif fmt == "msgpack":
-                _replace_pd_na(dataframe)
-                num_rows = len(dataframe)
-                # chunk number of records should not exceed 200 to avoid OSError
-                _chunk_record_size = max(chunk_record_size, num_rows//200)
-                try:
-                    for start in range(0, num_rows, _chunk_record_size):
-                        records = dataframe.iloc[
-                            start : start + _chunk_record_size
-                        ].to_dict(orient="records")
-                        fp = tempfile.NamedTemporaryFile(
-                            suffix=".msgpack.gz", delete=False
-                        )
-                        fp = self._write_msgpack_stream(records, fp)
-                        fps.append(fp)
-                        stack.callback(os.unlink, fp.name)
-                        stack.callback(fp.close)
-                except OSError as e:
-                    raise RuntimeError(
-                        "failed to create a temporary file. "
-                        "Larger chunk_record_size may mitigate the issue."
-                    ) from e
-            else:
-                raise ValueError(
-                    f"unsupported format '{fmt}' for bulk import. "
-                    "should be 'csv' or 'msgpack'"
-                )
-            self._bulk_import(table, fps, if_exists, fmt, max_workers=max_workers)
-            stack.close()
-
-    def _bulk_import(self, table, file_likes, if_exists, fmt="csv", max_workers=5):
+    def _bulk_import(self, table, dataframe, if_exists, fmt="csv", max_workers=5, chunk_record_size=10_000):
         """Write a specified CSV file to a Treasure Data table.
 
         This method uploads the file to Treasure Data via bulk import API.
@@ -488,8 +451,7 @@ def _bulk_import(self, table, file_likes, if_exists, fmt="csv", max_workers=5):
         table : :class:`pytd.table.Table`
             Target table.
 
-        file_likes : List of file like objects
-            Data in this file will be loaded to a target table.
+        dataframe : DataFrame to be uploaded
 
         if_exists : str, {'error', 'overwrite', 'append', 'ignore'}
             What happens when a target table already exists.
@@ -505,6 +467,10 @@ def _bulk_import(self, table, file_likes, if_exists, fmt="csv", max_workers=5):
         max_workers : int, optional, default: 5
             The maximum number of threads that can be used to execute the given calls.
             This is used only when ``fmt`` is ``msgpack``.
+
+        chunk_record_size : int, optional, default: 10_000
+            The number of records to be written in a single file. This is used only when
+            ``fmt`` is ``msgpack``.
         """
         params = None
         if table.exists:
@@ -530,27 +496,48 @@ def _bulk_import(self, table, file_likes, if_exists, fmt="csv", max_workers=5):
             session_name, table.database, table.table, params=params
         )
         s_time = time.time()
-        try:
-            logger.info(f"uploading data converted into a {fmt} file")
-            if fmt == "msgpack":
-                with ThreadPoolExecutor(max_workers=max_workers) as executor:
-                    for i, fp in enumerate(file_likes):
-                        fsize = fp.tell()
-                        fp.seek(0)
-                        executor.submit(
-                            bulk_import.upload_part,
-                            f"part-{i}",
-                            fp,
-                            fsize,
-                        )
-                        logger.debug(f"to upload {fp.name} to TD. File size: {fsize}B")
-            else:
-                fp = file_likes[0]
-                bulk_import.upload_file("part", fmt, fp)
-            bulk_import.freeze()
-        except Exception as e:
-            bulk_import.delete()
-            raise RuntimeError(f"failed to upload file: {e}")
+        with ExitStack() as stack:
+            try:
+                logger.info(f"uploading data converted into a {fmt} file")
+                if fmt == "csv":
+                    fp = stack.enter_context(tempfile.NamedTemporaryFile(suffix=".csv", delete=False))
+                    dataframe.to_csv(fp.name)
+                    bulk_import.upload_file("part", fmt, fp)
+                    os.unlink(fp.name)
+                    fp.close()
+                elif fmt == "msgpack":
+                    _replace_pd_na(dataframe)
+                    num_rows = len(dataframe)
+
+                    with ThreadPoolExecutor(max_workers=max_workers) as executor:
+                        for i, start in enumerate(range(0, num_rows, chunk_record_size)):
+                            records = dataframe.iloc[
+                                start : start + chunk_record_size
+                            ].to_dict(orient="records")
+                            fp = stack.enter_context(tempfile.NamedTemporaryFile(
+                                suffix=".msgpack.gz", delete=False
+                            ))
+                            fp = self._write_msgpack_stream(records, fp)
+                            fsize = fp.tell()
+                            fp.seek(0)
+                            executor.submit(
+                                bulk_import.upload_part,
+                                f"part-{i}",
+                                fp,
+                                fsize,
+                            )
+                            logger.debug(f"to upload {fp.name} to TD. File size: {fsize}B")
+                            os.unlink(fp.name)
+                            fp.close()
+                else:
+                    raise ValueError(
+                        f"unsupported format '{fmt}' for bulk import. "
+                        "should be 'csv' or 'msgpack'"
+                    )
+                bulk_import.freeze()
+            except Exception as e:
+                bulk_import.delete()
+                raise RuntimeError(f"failed to upload file: {e}")
         logger.debug(f"uploaded data in {time.time() - s_time:.2f} sec")
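
Note (not part of the patch): a minimal usage sketch of the reworked bulk import path, assuming extra keyword arguments such as fmt, max_workers, and chunk_record_size are forwarded by pytd's Client.load_table_from_dataframe to BulkImportWriter.write_dataframe, as in current pytd. The API key, endpoint, database, and table names below are placeholders.

    import pandas as pd
    import pytd

    # Placeholder credentials and destination; replace with real values.
    client = pytd.Client(
        apikey="1/XXXXXXXX",
        endpoint="https://api.treasuredata.com/",
        database="sample_db",
    )

    df = pd.DataFrame({"id": range(50_000), "value": [i * 0.5 for i in range(50_000)]})

    # With this patch, each chunk of `chunk_record_size` records is written to its own
    # temporary .msgpack.gz file inside _bulk_import and uploaded as a separate part
    # by up to `max_workers` threads.
    client.load_table_from_dataframe(
        df,
        "bulk_import_example",
        writer="bulk_import",
        if_exists="overwrite",
        fmt="msgpack",
        max_workers=5,
        chunk_record_size=10_000,
    )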