diff --git a/pytd/writer.py b/pytd/writer.py index ad68c80..e9614eb 100644 --- a/pytd/writer.py +++ b/pytd/writer.py @@ -461,25 +461,36 @@ def write_dataframe( # chunk number of records should not exceed 200 to avoid OSError _chunk_record_size = max(chunk_record_size, num_rows // 200) try: - range_func = ( - tqdm( - range(0, num_rows, _chunk_record_size), - desc="Chunking into msgpack", - ) - if show_progress - else range(0, num_rows, _chunk_record_size) - ) - for start in range_func: - records = dataframe.iloc[ - start : start + _chunk_record_size - ].to_dict(orient="records") - fp = tempfile.NamedTemporaryFile( - suffix=".msgpack.gz", delete=False + with ThreadPoolExecutor(max_workers=max_workers) as executor: + range_func = ( + tqdm( + range(0, num_rows, _chunk_record_size), + desc="Chunking into msgpack", + ) + if show_progress + else range(0, num_rows, _chunk_record_size) ) - fp = self._write_msgpack_stream(records, fp) - fps.append(fp) - stack.callback(os.unlink, fp.name) - stack.callback(fp.close) + + futures = [] + for start in range_func: + records = dataframe.iloc[ + start : start + _chunk_record_size + ].to_dict(orient="records") + fp = tempfile.NamedTemporaryFile( + suffix=".msgpack.gz", delete=False + ) + futures.append( + ( + start, + executor.submit( + self._write_msgpack_stream, records, fp + ), + ) + ) + stack.callback(os.unlink, fp.name) + stack.callback(fp.close) + for start, future in sorted(futures): + fps.append(future.result()) except OSError as e: raise RuntimeError( "failed to create a temporary file. "